cs7324 Lab 5 - Wide and Deep Networks¶
Chip Henderson - 48996654¶
For this lab I am changing my dataset to the Breast Cancer Gene Expression Profiles (METABRIC) dataset (https://www.kaggle.com/datasets/raghadalharbi/breast-cancer-gene-expression-profiles-metabric).
I wanted a dataset with more feature data for this lab. Additionally, having a history of breast cancer in my family, I was curious about this dataset, which makes the lab more interesting. I'll be upfront about the fact that my lab is somewhat morbid: I'll be predicting whether a patient is likely to live or die based on the characteristics of their cancer. However, based on a personal family experience, a model like this may have helped drive a more realistic discussion about the treatment plan. That aside, thanks to my wife, who is an Oncology nurse practitioner and helped me understand some of these terms.
1. Preparation¶
Preprocessing and Class Variable Definition¶
import pandas as pd
bc_df = pd.read_csv(r'c:\users\chip\source\repos\cs7324_code\Lab 5\METABRIC_RNA_Mutation.csv', sep=',')
bc_df.shape
C:\Users\Chip\AppData\Local\Temp\ipykernel_17188\103049857.py:3: DtypeWarning: Columns (678,688,690,692) have mixed types. Specify dtype option on import or set low_memory=False. bc_df = pd.read_csv(r'c:\users\chip\source\repos\cs7324_code\Lab 5\METABRIC_RNA_Mutation.csv', sep=',')
(1904, 693)
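The DtypeWarning above appears because pandas reads large CSVs in chunks by default (`low_memory=True`), and the flagged columns mix numbers and strings across chunks. As the warning suggests, passing `low_memory=False` (or an explicit `dtype`) resolves it. A minimal sketch with a tiny stand-in CSV (the mutation value here is made up for illustration):

```python
import io
import pandas as pd

# Stand-in for the METABRIC file: pik3ca_mut mixes numeric 0s with
# string mutation codes, which is what triggers the DtypeWarning.
csv_text = "patient_id,pik3ca_mut\n1,0\n2,H1047R\n3,0\n"

# low_memory=False reads in one pass; the explicit dtype guarantees the
# mixed column comes back as a single consistent string (object) column.
df = pd.read_csv(io.StringIO(csv_text),
                 dtype={"pik3ca_mut": str},
                 low_memory=False)
print(df["pik3ca_mut"].dtype)  # object
```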
# Source: https://stackoverflow.com/questions/34537048/how-to-count-nan-values-in-a-pandas-dataframe
bc_df.info(verbose=True, show_counts=True)
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1904 entries, 0 to 1903
Data columns (total 693 columns):
 #        Column                                        Non-Null Count  Dtype
---       ------                                        --------------  -----
 0        patient_id                                    1904 non-null   int64
 1        age_at_diagnosis                              1904 non-null   float64
 2        type_of_breast_surgery                        1882 non-null   object
 3        cancer_type                                   1904 non-null   object
 4        cancer_type_detailed                          1889 non-null   object
 5        cellularity                                   1850 non-null   object
 6        chemotherapy                                  1904 non-null   int64
 7        pam50_+_claudin-low_subtype                   1904 non-null   object
 8        cohort                                        1904 non-null   float64
 9        er_status_measured_by_ihc                     1874 non-null   object
 10       er_status                                     1904 non-null   object
 11       neoplasm_histologic_grade                     1832 non-null   float64
 12       her2_status_measured_by_snp6                  1904 non-null   object
 13       her2_status                                   1904 non-null   object
 14       tumor_other_histologic_subtype                1889 non-null   object
 15       hormone_therapy                               1904 non-null   int64
 16       inferred_menopausal_state                     1904 non-null   object
 17       integrative_cluster                           1904 non-null   object
 18       primary_tumor_laterality                      1798 non-null   object
 19       lymph_nodes_examined_positive                 1904 non-null   float64
 20       mutation_count                                1859 non-null   float64
 21       nottingham_prognostic_index                   1904 non-null   float64
 22       oncotree_code                                 1889 non-null   object
 23       overall_survival_months                       1904 non-null   float64
 24       overall_survival                              1904 non-null   int64
 25       pr_status                                     1904 non-null   object
 26       radio_therapy                                 1904 non-null   int64
 27       3-gene_classifier_subtype                     1700 non-null   object
 28       tumor_size                                    1884 non-null   float64
 29       tumor_stage                                   1403 non-null   float64
 30       death_from_cancer                             1903 non-null   object
 31-519   gene expression values (brca1 ... ugt2b7)     1904 non-null   float64
 520-692  gene mutation details (pik3ca_mut ... siah1_mut)  1904 non-null   object
dtypes: float64(498), int64(5), object(190)
memory usage: 10.1+ MB
bc_df.head()
| | patient_id | age_at_diagnosis | type_of_breast_surgery | cancer_type | cancer_type_detailed | cellularity | chemotherapy | pam50_+_claudin-low_subtype | cohort | er_status_measured_by_ihc | ... | mtap_mut | ppp2cb_mut | smarcd1_mut | nras_mut | ndfip1_mut | hras_mut | prps2_mut | smarcb1_mut | stmn2_mut | siah1_mut |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 75.65 | MASTECTOMY | Breast Cancer | Breast Invasive Ductal Carcinoma | NaN | 0 | claudin-low | 1.0 | Positve | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 1 | 2 | 43.19 | BREAST CONSERVING | Breast Cancer | Breast Invasive Ductal Carcinoma | High | 0 | LumA | 1.0 | Positve | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 2 | 5 | 48.87 | MASTECTOMY | Breast Cancer | Breast Invasive Ductal Carcinoma | High | 1 | LumB | 1.0 | Positve | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 3 | 6 | 47.68 | MASTECTOMY | Breast Cancer | Breast Mixed Ductal and Lobular Carcinoma | Moderate | 1 | LumB | 1.0 | Positve | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
| 4 | 8 | 76.97 | MASTECTOMY | Breast Cancer | Breast Mixed Ductal and Lobular Carcinoma | High | 1 | LumB | 1.0 | Positve | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 693 columns
bc_df.describe()
| | patient_id | age_at_diagnosis | chemotherapy | cohort | neoplasm_histologic_grade | hormone_therapy | lymph_nodes_examined_positive | mutation_count | nottingham_prognostic_index | overall_survival_months | ... | srd5a1 | srd5a2 | srd5a3 | st7 | star | tnk2 | tulp4 | ugt2b15 | ugt2b17 | ugt2b7 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| count | 1904.000000 | 1904.000000 | 1904.000000 | 1904.000000 | 1832.000000 | 1904.000000 | 1904.000000 | 1859.000000 | 1904.000000 | 1904.000000 | ... | 1.904000e+03 | 1.904000e+03 | 1.904000e+03 | 1.904000e+03 | 1904.000000 | 1.904000e+03 | 1.904000e+03 | 1.904000e+03 | 1904.000000 | 1.904000e+03 |
| mean | 3921.982143 | 61.087054 | 0.207983 | 2.643908 | 2.415939 | 0.616597 | 2.002101 | 5.697687 | 4.033019 | 125.121324 | ... | 4.726891e-07 | -3.676471e-07 | -9.453782e-07 | -1.050420e-07 | -0.000002 | 3.676471e-07 | 4.726891e-07 | 7.878151e-07 | 0.000000 | 3.731842e-18 |
| std | 2358.478332 | 12.978711 | 0.405971 | 1.228615 | 0.650612 | 0.486343 | 4.079993 | 4.058778 | 1.144492 | 76.334148 | ... | 1.000263e+00 | 1.000262e+00 | 1.000262e+00 | 1.000263e+00 | 1.000262 | 1.000264e+00 | 1.000262e+00 | 1.000263e+00 | 1.000262 | 1.000262e+00 |
| min | 0.000000 | 21.930000 | 0.000000 | 1.000000 | 1.000000 | 0.000000 | 0.000000 | 1.000000 | 1.000000 | 0.000000 | ... | -2.120800e+00 | -3.364800e+00 | -2.719400e+00 | -4.982700e+00 | -2.981700 | -3.833300e+00 | -3.609300e+00 | -1.166900e+00 | -2.112600 | -1.051600e+00 |
| 25% | 896.500000 | 51.375000 | 0.000000 | 1.000000 | 2.000000 | 0.000000 | 0.000000 | 3.000000 | 3.046000 | 60.825000 | ... | -6.188500e-01 | -6.104750e-01 | -6.741750e-01 | -6.136750e-01 | -0.632900 | -6.664750e-01 | -7.102000e-01 | -5.058250e-01 | -0.476200 | -7.260000e-01 |
| 50% | 4730.500000 | 61.770000 | 0.000000 | 3.000000 | 3.000000 | 1.000000 | 0.000000 | 5.000000 | 4.042000 | 115.616667 | ... | -2.456500e-01 | -4.690000e-02 | -1.422500e-01 | -5.175000e-02 | -0.026650 | 7.000000e-04 | -2.980000e-02 | -2.885500e-01 | -0.133400 | -4.248000e-01 |
| 75% | 5536.250000 | 70.592500 | 0.000000 | 3.000000 | 3.000000 | 1.000000 | 2.000000 | 7.000000 | 5.040250 | 184.716667 | ... | 3.306000e-01 | 5.144500e-01 | 5.146000e-01 | 5.787750e-01 | 0.590350 | 6.429000e-01 | 5.957250e-01 | 6.022500e-02 | 0.270375 | 4.284000e-01 |
| max | 7299.000000 | 96.290000 | 1.000000 | 5.000000 | 3.000000 | 1.000000 | 45.000000 | 80.000000 | 6.360000 | 355.200000 | ... | 6.534900e+00 | 1.027030e+01 | 6.329000e+00 | 4.571300e+00 | 12.742300 | 3.938800e+00 | 3.833400e+00 | 1.088490e+01 | 12.643900 | 3.284400e+00 |
8 rows × 503 columns
I'll go ahead and drop NA values to help reduce the dataset to more usable data. I'm also dropping the following:
- 'cancer_type', as this entire dataset is breast cancer related and all the values are the same
- 'cohort', because this is an assigned value, not a measured one that could help prediction
- 'overall_survival', because this will be represented by my classification groups, and it would also have provided misleading results if the person didn't die of the disease
I'm also going to drop the genetic attributes (features 31 through 693) for the following reasons:
- They increase the size of the dataset significantly
- The values aren't easily understandable
- To non-medical professionals, they don't necessarily provide a meaningful contribution to the classification objective
# bc_df_full = bc_df.copy() # keep an untouched copy for use later
bc_df = bc_df.dropna(axis=0)  # drop any row with a missing value
bc_df = bc_df.drop(bc_df.columns[31:693], axis=1)  # drop the genetic attributes
features_to_drop = ['cancer_type', 'overall_survival', 'cohort']
bc_df = bc_df.drop(features_to_drop, axis=1)
bc_df = bc_df.reset_index(drop=True)  # reset_index returns a new frame, so assign the result
# rename column to remove + symbol to avoid any potential datatype issues
bc_df.rename(columns={'pam50_+_claudin-low_subtype':'pam50_plus_claudin-low_subtype'},inplace=True)
bc_df.shape
(1092, 28)
That removed roughly 800 instances and more than 600 columns, which makes my data far easier to work with. To verify I have no more missing data, I'll import missingno to visualize it.
# Referencing code from lecture and in-class examples
import matplotlib
import matplotlib.pyplot as plt
import warnings
warnings.simplefilter('ignore', DeprecationWarning)
%matplotlib inline
import missingno as mn
# As a departure from lecture code I'm using a bar chart,
# Matrix version gave me errors regarding 'grid_b' which I wasn't able to resolve
mn.bar(bc_df)
<Axes: >
No more missing data in my features, so we're good to proceed.
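As a numeric cross-check of the bar chart (in the spirit of the StackOverflow reference above), `isna().sum()` counts remaining NaNs per column without plotting anything. A minimal sketch on a toy frame; on `bc_df` every count should be zero after the `dropna` above:

```python
import numpy as np
import pandas as pd

# Toy frame standing in for bc_df: column "a" has one missing value.
toy = pd.DataFrame({"a": [1.0, np.nan, 3.0], "b": ["x", "y", "z"]})

# Count NaNs per column; filtering to > 0 shows only problem columns,
# so an empty result means no missing data anywhere.
missing = toy.isna().sum()
print(missing[missing > 0])
```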
# Determine the remaining datatypes I'm working with
bc_df.info()
<class 'pandas.core.frame.DataFrame'>
Index: 1092 entries, 1 to 1664
Data columns (total 28 columns):
 #   Column                          Non-Null Count  Dtype
---  ------                          --------------  -----
 0   patient_id                      1092 non-null   int64
 1   age_at_diagnosis                1092 non-null   float64
 2   type_of_breast_surgery          1092 non-null   object
 3   cancer_type_detailed            1092 non-null   object
 4   cellularity                     1092 non-null   object
 5   chemotherapy                    1092 non-null   int64
 6   pam50_plus_claudin-low_subtype  1092 non-null   object
 7   er_status_measured_by_ihc       1092 non-null   object
 8   er_status                       1092 non-null   object
 9   neoplasm_histologic_grade       1092 non-null   float64
 10  her2_status_measured_by_snp6    1092 non-null   object
 11  her2_status                     1092 non-null   object
 12  tumor_other_histologic_subtype  1092 non-null   object
 13  hormone_therapy                 1092 non-null   int64
 14  inferred_menopausal_state       1092 non-null   object
 15  integrative_cluster             1092 non-null   object
 16  primary_tumor_laterality        1092 non-null   object
 17  lymph_nodes_examined_positive   1092 non-null   float64
 18  mutation_count                  1092 non-null   float64
 19  nottingham_prognostic_index     1092 non-null   float64
 20  oncotree_code                   1092 non-null   object
 21  overall_survival_months         1092 non-null   float64
 22  pr_status                       1092 non-null   object
 23  radio_therapy                   1092 non-null   int64
 24  3-gene_classifier_subtype       1092 non-null   object
 25  tumor_size                      1092 non-null   float64
 26  tumor_stage                     1092 non-null   float64
 27  death_from_cancer               1092 non-null   object
dtypes: float64(8), int64(4), object(16)
memory usage: 247.4+ KB
For my classification objective I want to know what the possible outcomes are:
unique_survivability = list(enumerate(bc_df.death_from_cancer.unique()))
print(unique_survivability)
[(0, 'Living'), (1, 'Died of Disease'), (2, 'Died of Other Causes')]
'Died of Other Causes' may not be of value to me: if a patient neither survived nor succumbed to the disease, the instance won't support accurate prediction. One way to read 'Died of Other Causes' is that the patient survived the disease; but since we don't know the cause of death, cancer may still have proven fatal given a long enough life span.
I'll start by seeing how many values are in this category then deciding what to do with it.
survivability_list = [x for x in bc_df['death_from_cancer'] if x == 'Died of Other Causes']
print(f'There are {len(survivability_list)} deaths related to other causes out of {bc_df.shape[0]} instances')
There are 238 deaths related to other causes out of 1092 instances
I'm going to drop these values even though it'll put my number of instances under 1k, smaller than desired. I'll review the class balances later and determine whether I should add some additional samples via oversampling.
bool_of_unrelated_deaths = (bc_df['death_from_cancer'] == 'Died of Other Causes')
idx_matching = bc_df[bool_of_unrelated_deaths].index
bc_df = bc_df.drop(idx_matching,axis=0)
print(bc_df.shape)
(854, 28)
I'd like to understand how well balanced my two classification groups are, as that will impact how well my model may perform. So I'll check that next.
import matplotlib
import matplotlib.pyplot as plt
outcomes = bc_df.groupby(['death_from_cancer'])
outcomes.count().plot(kind='pie',
y='patient_id',
autopct='%1.1f%%',
title = "Quantity of Each Outcome")
<Axes: title={'center': 'Quantity of Each Outcome'}, ylabel='patient_id'>
The pie chart is a simple and visually effective way to represent the balance in my two classes. The balance is fairly close, but I'd like the classes to be as even as possible, so I'm going to oversample by repeating several instances from class 1 ('Died of Disease'). This is an appropriate technique because I want the model to learn the characteristics of the cancer that make it more lethal, thereby providing confirmational guidance for care providers to have open and candid conversations with their patients. If I were instead to fabricate instances, or impute and alter the outcomes to balance the classes, that would be an inappropriate balancing technique.
There are 854 entries total, so if I want to have each value be equal the number of instances where the individual died needs to be increased by (.567-.433)*854 = ~114 instances.
# Add instances to balance the classes
bool_of_related_deaths = (bc_df['death_from_cancer'] == 'Died of Disease')
idx_matching_1 = bc_df[bool_of_related_deaths].index
bc_df_died = bc_df.loc[idx_matching_1]
bc_df_died = bc_df_died[:114] # only need a few
bc_df = pd.concat([bc_df, bc_df_died], ignore_index=True)
print(bc_df.shape)
(968, 28)
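The slice above duplicates the first 114 'Died of Disease' rows in order of appearance. A common alternative is a random sample with replacement, so the duplicates aren't biased toward the start of the frame. A minimal sketch, using a hypothetical stand-in frame (not the real `bc_df_died`) and a fixed seed:

```python
import pandas as pd

# Hypothetical stand-in for bc_df_died (the minority-class rows)
minority = pd.DataFrame({'patient_id': range(10),
                         'death_from_cancer': ['Died of Disease'] * 10})

# Randomly oversample 114 rows with replacement instead of slicing the first 114;
# random_state makes the draw reproducible
extra = minority.sample(n=114, replace=True, random_state=1)
```

The result would then be concatenated onto `bc_df` exactly as in the cell above.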
# Take another look at our pie chart to verify the classes are balanced.
outcomes = bc_df.groupby(['death_from_cancer'])
outcomes.count().plot(kind='pie',
y='patient_id',
autopct='%1.1f%%',
title="Quantity of Each Outcome After Oversampling")
<Axes: title={'center': 'Quantity of Each Outcome After Oversampling'}, ylabel='patient_id'>
Now that my outcomes are balanced, I'm going to drop patient ID before moving on, because it is an assigned value with no predictive use. I can use the instance index values if needed to refer to a particular instance. Also, I'll move the feature I intend to predict to the last position in the dataframe, mostly for readability when inspecting the data.
# Drop patient id
bc_df = bc_df.drop('patient_id', axis=1)
# Move column 'death_from_cancer' to the end
bc_df = bc_df[[col for col in bc_df.columns if col != 'death_from_cancer'] + ['death_from_cancer']]
bc_df.info(verbose=True)
<class 'pandas.core.frame.DataFrame'> RangeIndex: 968 entries, 0 to 967 Data columns (total 27 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 age_at_diagnosis 968 non-null float64 1 type_of_breast_surgery 968 non-null object 2 cancer_type_detailed 968 non-null object 3 cellularity 968 non-null object 4 chemotherapy 968 non-null int64 5 pam50_plus_claudin-low_subtype 968 non-null object 6 er_status_measured_by_ihc 968 non-null object 7 er_status 968 non-null object 8 neoplasm_histologic_grade 968 non-null float64 9 her2_status_measured_by_snp6 968 non-null object 10 her2_status 968 non-null object 11 tumor_other_histologic_subtype 968 non-null object 12 hormone_therapy 968 non-null int64 13 inferred_menopausal_state 968 non-null object 14 integrative_cluster 968 non-null object 15 primary_tumor_laterality 968 non-null object 16 lymph_nodes_examined_positive 968 non-null float64 17 mutation_count 968 non-null float64 18 nottingham_prognostic_index 968 non-null float64 19 oncotree_code 968 non-null object 20 overall_survival_months 968 non-null float64 21 pr_status 968 non-null object 22 radio_therapy 968 non-null int64 23 3-gene_classifier_subtype 968 non-null object 24 tumor_size 968 non-null float64 25 tumor_stage 968 non-null float64 26 death_from_cancer 968 non-null object dtypes: float64(8), int64(3), object(16) memory usage: 204.3+ KB
Final Dataset Description¶
This final dataset is comprised of 968 breast cancer patients with 27 features related to their case. At this point, the data has not been either one-hot encoded or label-encoded. I'll do that as part of my FeatureSpace setup. The target data is an object datatype containing the outcome of whether or not the patient lived, making this model a binary classifier.
The outcome (target) values are balanced, with half of the instances relating to patients that lived and half relating to patients that died of the disease. I've removed any instances with missing values from the dataset, and I have removed the gene-specific z-values, which comprised a great deal of the original features but were not value-added for my purposes. The data is not yet scaled.
Cross-Product Feature Identification¶
To start, I'm interested in how many unique values I have in each feature:
# Source: for dictionary sorting https://stackoverflow.com/questions/64885734/how-to-sort-a-dictionary-in-descending-order-according-its-value
# unique_feature_count = [feature for feature in bc_df.columns if feature != 'death_from_cancer']
unique_feature_count = [feature for feature in bc_df.columns if bc_df[feature].dtype == object]
unique_dict = {}
for feature in unique_feature_count:
    unique_vals = bc_df[feature].nunique()
    unique_dict[feature] = unique_vals
unique_dict_sorted = sorted(unique_dict.items(), key=lambda x:x[1], reverse=True)
for item in unique_dict_sorted:
    print(f'There are {item[1]} unique values in {item[0]}')
There are 11 unique values in integrative_cluster There are 7 unique values in pam50_plus_claudin-low_subtype There are 7 unique values in tumor_other_histologic_subtype There are 5 unique values in cancer_type_detailed There are 5 unique values in oncotree_code There are 4 unique values in her2_status_measured_by_snp6 There are 4 unique values in 3-gene_classifier_subtype There are 3 unique values in cellularity There are 2 unique values in type_of_breast_surgery There are 2 unique values in er_status_measured_by_ihc There are 2 unique values in er_status There are 2 unique values in her2_status There are 2 unique values in inferred_menopausal_state There are 2 unique values in primary_tumor_laterality There are 2 unique values in pr_status There are 2 unique values in death_from_cancer
To determine which features to cross, I'd like to understand which ones are correlated to each other. To do that, I need to one-hot encode some of my values. Because I'm going to encode again later, I'll keep this as a separate dataframe.
# One-hot encode other object values
# I'll write a loop for this since there are several
import copy
bc_df_encoded = bc_df.copy()
# limit to categorical features
features_to_encode = [label for label in bc_df_encoded.columns if bc_df_encoded.dtypes[label] == object]
# print(features_to_encode) # debug
for feature in features_to_encode:
    tmp_df = pd.get_dummies(bc_df[feature], prefix=feature)
    bc_df_encoded = pd.concat((bc_df_encoded, tmp_df), axis=1)
    bc_df_encoded = bc_df_encoded.drop(feature, axis=1)  # drop original column
# Let's pull some basic correlation data to see if that will help identify features to cross product
# Check correlation of each feature to 'death_from_cancer' first
features_to_correlate = [feature for feature in bc_df_encoded.columns
if feature != 'death_from_cancer'
and bc_df_encoded[feature].dtype == bool]
corr_to_outcome = [bc_df_encoded[feature].corr(bc_df_encoded['death_from_cancer_Died of Disease'])
for feature in features_to_correlate]
vars_to_use = []
for feature, value in zip(features_to_correlate, corr_to_outcome):
    if value >= 0.1:  # keep only features with correlation >= 0.1
        print(f'The correlation of {feature} to outcome is {round(value, 3)}')
        vars_to_use.append(feature)
The correlation of type_of_breast_surgery_MASTECTOMY to outcome is 0.201 The correlation of pam50_plus_claudin-low_subtype_Her2 to outcome is 0.138 The correlation of pam50_plus_claudin-low_subtype_LumB to outcome is 0.118 The correlation of her2_status_measured_by_snp6_GAIN to outcome is 0.106 The correlation of her2_status_Positive to outcome is 0.137 The correlation of inferred_menopausal_state_Post to outcome is 0.1 The correlation of integrative_cluster_5 to outcome is 0.159 The correlation of pr_status_Negative to outcome is 0.11 The correlation of 3-gene_classifier_subtype_ER+/HER2- High Prolif to outcome is 0.14 The correlation of 3-gene_classifier_subtype_HER2+ to outcome is 0.12 The correlation of death_from_cancer_Died of Disease to outcome is 1.0
No single feature appears to be strongly correlated with my target values. However, it would also be useful to understand how some of these features correlate to one another. I'll put together a heatmap of some of these values to visualize that.
# Source: modified from lab_1 in-class lectures
import seaborn as sns
# plot the correlation matrix using a subset of features
cmap = sns.set(style="darkgrid") # one of the many styles to plot using
f, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(bc_df_encoded[vars_to_use].corr(), cmap=cmap, annot=True)
<Axes: >
Generally, I'm going to pick things that can be thought of as rules. As such, I'm looking for features that are related, or that appear to move as though they are related, so I'm using correlation as a guiding principle. Second, I'll look for features that seem like they would interact.
Therefore, I'm going to cross the following:
- her2_status_measured_by_snp6 and her2_status. These two items are related in that both are used to classify the cancer subtype (Source: mayoclinic.org)
- 3-gene_classifier_subtype and integrative_cluster. "Integrative clustering is a breast cancer classification of 10 different subgroups with distinctive molecular profiles and clinical outcomes" (source: https://ascopubs.org/doi/abs/10.1200/JCO.2018.36.15_suppl.579)
- 3-gene_classifier, integrative_cluster, pam50_plus is another option to consider as all are forms of classification and should move together
- er_status and er_status_measured_by_ihc should also move together though I didn't put them in my correlation heatmap
- Another experiment to try would be mapping chemotherapy, radio_therapy, and hormone_therapy to true/false instead of 0/1, because then I could cross those features (I tried to cross them as integer values, which I quickly learned you can't do)
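As a sketch of that last idea (with hypothetical column values, not the real `bc_df`), the 0/1 therapy flags can be mapped to strings so they can be treated as categorical features and crossed:

```python
import pandas as pd

# Hypothetical stand-in for the therapy columns of bc_df
df = pd.DataFrame({'chemotherapy': [0, 1, 1],
                   'radio_therapy': [1, 0, 1]})

# Feature crossing requires categorical inputs, so map the 0/1 integers to strings first
for col in ['chemotherapy', 'radio_therapy']:
    df[col] = df[col].map({0: 'no', 1: 'yes'})
```

After this, the columns could be declared as `FeatureSpace.string_categorical()` entries and listed in `crosses=[...]`.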
Metric Selection and Reasoning¶
For this model I'll be using F-measure for assessing performance
The model will be designed as a binary classifier of the probable outcome of breast cancer based on the inputs. So the metrics of primary interest are true positives (the patient is likely to die of disease) and false negatives (the patient died of disease but the model didn't predict it).
I'll take a moment to discuss false positives, as their impact can be viewed differently depending on perspective. A false positive for this model would be a prediction that the patient will die from disease that ends up being incorrect. This presents an obviously difficult situation for the patient due to the psychological impact of such a diagnosis. The counterbalance, however, is that the patient's outcome is much better. Despite this potential upside, the patient may make life-altering decisions based on a diagnosis this model provides. Therefore I have to treat false positives with almost equal importance as true positives and false negatives.
To choose a metric that meets these requirements, I need to assess the measurements used in calculating the metrics. I would like to select a metric that emphasizes:
- True Positives
- False Negatives
- False Positives
Precision uses True Positives and False Positives. It does not weigh false negatives. Since I believe false negatives will be important to my model's performance, this is not a viable option. Recall uses True Positives and False Negatives; again, it addresses two of my areas of interest but not all three, as False Positives are missing.
F-Measure combines precision and recall with equal balances between the two. This is an ideal measure to use for my purposes as it addresses the three areas of interest with equal weightings. I'll plan on using F-measure in assessing my model's performance.
Another option to consider would be F_beta, which allows me to change the relative weight of recall in the calculation (beta < 1 down-weights recall). However, knowing that recall uses True Positives and False Negatives, which are both of high importance to me, I don't want to down-weight that particular component in this circumstance.
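To make the comparison concrete, here is a worked example of these metrics from raw confusion counts (the tp/fp/fn numbers here are made up for illustration, not taken from the METABRIC data):

```python
# Hypothetical confusion counts for a binary classifier
tp, fp, fn = 80, 10, 20

precision = tp / (tp + fp)   # ignores false negatives
recall = tp / (tp + fn)      # ignores false positives

# F1 weighs precision and recall equally
f1 = 2 * precision * recall / (precision + recall)

def f_beta(p, r, beta):
    # beta < 1 down-weights recall (favors precision); beta > 1 favors recall
    return (1 + beta ** 2) * p * r / (beta ** 2 * p + r)
```

Note that `f_beta(precision, recall, 1.0)` reduces to the plain F1 score.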
Methods for Dividing and Testing the Data¶
For this model I'll be using K-fold for Test/Train Split
To assess which method for my testing and training data is best I need to understand some basic information about how I've set up my data. I purposefully oversampled my data for this lab such that the results of my binary output are an even 50/50 split. Additionally, having 968 instances and 27 features, the size of the dataset is in the small to medium range.
The options I have for splitting my testing and training data include holdout, random subsampling, and K-fold or Stratified K-fold. Because my dataset is in the small to medium size range, holdout and random subsampling are likely not required; random subsampling is more applicable to a very large dataset, where training a model on all the data is inefficient. Holdout is a viable option, but K-fold provides a more thorough understanding of how my model is performing on the data.
K-fold will help ensure I have an evenly divided test dataset and that any ordering trends in the data are mitigated; I'll also shuffle to further address this. Stratified K-fold would be appropriate if my dataset were unbalanced in its outcomes. Even before the oversampling shown above, this dataset was relatively even, so I should be fine using plain K-fold without the stratified technique. If I need an alternative, I'll consider holdout as the next most viable selection.
# Setup my feature labels with appropriate variables needed later
import numpy as np
# create a tensorflow dataset, for ease of use later
batch_size = 44 # 44 divides evenly into my total instances, 968
# Map the classification groups to integers
# This could have been done in FeatureSpace, but doing here as a matter of preference
survivability_dict = {'Living':0, 'Died of Disease':1}
bc_df['death_from_cancer'] = bc_df['death_from_cancer'].map(survivability_dict)
categorical_headers = [label for label in bc_df.columns if bc_df.dtypes[label] == object]
int_headers = ['chemotherapy','radio_therapy','hormone_therapy']
numeric_headers = [label for label in bc_df.columns if bc_df.dtypes[label] == float] + int_headers
# Perform the test/train split of the data
# Source: https://stackoverflow.com/questions/45115964/separate-pandas-dataframe-using-sklearns-kfold
# Source: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html
# Source: https://machinelearningmastery.com/k-fold-cross-validation/
from sklearn.model_selection import KFold
kf = KFold(n_splits=6, shuffle=True, random_state=1) # 6 splits resulted in ~80/20 split
result = next(kf.split(bc_df), None) # returns indices of train/test instances for the first fold
bc_df_train = bc_df.iloc[result[0]]
bc_df_test = bc_df.iloc[result[1]]
print(bc_df_train.shape)
print(bc_df_test.shape)
(806, 27) (162, 27)
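Note that `next(kf.split(...))` keeps only the first of the six folds. A full cross-validation run would iterate all folds; a sketch with the same KFold settings and a stand-in index array (not the real dataframe):

```python
import numpy as np
from sklearn.model_selection import KFold

rows = np.arange(968)  # stand-in for the 968 instances of bc_df
kf = KFold(n_splits=6, shuffle=True, random_state=1)

fold_sizes = []
for train_idx, test_idx in kf.split(rows):
    # in the lab, the model would be fit on train_idx and scored on test_idx here
    fold_sizes.append((len(train_idx), len(test_idx)))
```

Averaging the per-fold scores then gives a more robust performance estimate than a single holdout split.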
2. Modeling¶
As previously mentioned, I'm going to be setting up all of my models using FeatureSpaces. This will give me an easy way to re-use configurations as necessary. I'm going to be creating 3 models in this first section, all of which are deep and wide representations.
from sklearn import metrics as mt
import tensorflow as tf
from tensorflow import keras
print(tf.__version__)
print(keras.__version__)
2.12.0 2.12.0
from tensorflow.keras.layers import Dense, Activation, Input
from tensorflow.keras.layers import Embedding, Concatenate, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.utils import FeatureSpace
# Adding a function to create a tensorflow dataset from dataframe
# Source: modified from in class lecture/notebook to align with my dataset
def create_dataset_from_dataframe(df_input):
    df = df_input.copy()
    labels = df['death_from_cancer']
    # keep only the feature columns, each reshaped to (n, 1)
    df = {key: value.values[:, np.newaxis] for key, value in df[categorical_headers + numeric_headers].items()}
    # create the Dataset here
    ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))
    # now enable batching and prefetching
    ds = ds.batch(batch_size)
    ds = ds.prefetch(batch_size)
    return ds
ds_train = create_dataset_from_dataframe(bc_df_train)
ds_test = create_dataset_from_dataframe(bc_df_test)
# Adding a function to create embeddings from the tensors
# Source: modified from in class lecture/notebook to align with my dataset
from tensorflow.keras.layers import Embedding, Flatten
def setup_embedding_from_categorical(feature_space, col_name):
    # the vocabulary size is the number of categories for this variable
    N = len(feature_space.preprocessors[col_name].get_vocabulary())
    # get the output from the feature space, which is the input to the embedding
    x = feature_space.preprocessors[col_name].output
    # now use an embedding to deal with the integers from the feature space
    x = Embedding(input_dim=N,
                  output_dim=int(np.sqrt(N)),
                  input_length=1, name=col_name + '_embed')(x)
    x = Flatten()(x)  # drop the extra dimension added by the embedding
    return x  # return the tensor here
#Source: Modified from in-class lecture
def setup_embedding_from_crossing(feature_space, col_name):
    # the number of hash bins is the size of the crossed feature
    N = feature_space.crossers[col_name].num_bins
    x = feature_space.crossers[col_name].output
    # now use an embedding to deal with the integers as if they were one-hot encoded
    x = Embedding(input_dim=N,
                  output_dim=int(np.sqrt(N)),
                  input_length=1, name=col_name + '_embed')(x)
    x = Flatten()(x)  # drop the extra dimension added by the embedding
    return x
# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
# Setup functions to allow for F1 calculation
# Note: I found this functionality to be deprecated in my version of Keras, so it required a manual implementation
from keras import backend as K
def recall_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

def precision_m(y_true, y_pred):
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision

def f1_m(y_true, y_pred):
    precision = precision_m(y_true, y_pred)
    recall = recall_m(y_true, y_pred)
    return 2 * ((precision * recall) / (precision + recall + K.epsilon()))
Now I can set up my models for Keras.
Model 1 of 3¶
I'll start all of my models by either setting up or re-using a FeatureSpace.
# Source: Modified from in-class lecture to match my dataset
from tensorflow.keras.utils import FeatureSpace
feature_space_1 = FeatureSpace(
features={
# Categorical feature encoded as string
"type_of_breast_surgery": FeatureSpace.string_categorical(num_oov_indices=0),
"cancer_type_detailed": FeatureSpace.string_categorical(num_oov_indices=0),
"cellularity": FeatureSpace.string_categorical(num_oov_indices=0),
"pam50_plus_claudin-low_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
"er_status_measured_by_ihc": FeatureSpace.string_categorical(num_oov_indices=0),
"er_status": FeatureSpace.string_categorical(num_oov_indices=0),
"her2_status_measured_by_snp6": FeatureSpace.string_categorical(num_oov_indices=0),
"her2_status": FeatureSpace.string_categorical(num_oov_indices=0),
"tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
# "tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
"inferred_menopausal_state": FeatureSpace.string_categorical(num_oov_indices=0),
"integrative_cluster": FeatureSpace.string_categorical(num_oov_indices=0),
"primary_tumor_laterality": FeatureSpace.string_categorical(num_oov_indices=0),
"oncotree_code": FeatureSpace.string_categorical(num_oov_indices=0),
"pr_status": FeatureSpace.string_categorical(num_oov_indices=0),
"3-gene_classifier_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
# "chemotherapy": FeatureSpace.string_categorical(num_oov_indices=0),
# "hormone_therapy": FeatureSpace.string_categorical(num_oov_indices=0),
# "radio_therapy": FeatureSpace.string_categorical(num_oov_indices=0),
# Numerical features to normalize (normalization will be learned)
# learns the mean, variance, and if to invert
"chemotherapy": FeatureSpace.float_normalized(),
"hormone_therapy": FeatureSpace.float_normalized(),
"radio_therapy": FeatureSpace.float_normalized(),
"age_at_diagnosis": FeatureSpace.float_normalized(),
"neoplasm_histologic_grade": FeatureSpace.float_normalized(),
"lymph_nodes_examined_positive": FeatureSpace.float_normalized(),
"mutation_count": FeatureSpace.float_normalized(),
"nottingham_prognostic_index": FeatureSpace.float_normalized(),
"overall_survival_months": FeatureSpace.float_normalized(),
"tumor_size": FeatureSpace.float_normalized(),
"tumor_stage": FeatureSpace.float_normalized(),
},
# Specify feature cross with a custom crossing dim
crosses=[
FeatureSpace.cross(
feature_names=('her2_status_measured_by_snp6','her2_status'),
crossing_dim=4*2),
FeatureSpace.cross(
feature_names=('3-gene_classifier_subtype', 'integrative_cluster'),
crossing_dim=4*11),
FeatureSpace.cross(
feature_names=('er_status', 'er_status_measured_by_ihc'),
crossing_dim=2*2),
],
output_mode="concat",
)
# now that we have specified the preprocessing, let's run it on the data
# create a version of the dataset that can be iterated without labels
train_ds_with_no_labels = ds_train.map(lambda x, _: x)
feature_space_1.adapt(train_ds_with_no_labels) # initialize the feature map to this data
# the adapt function allows the model to learn one-hot encoding sizes
# I won't be using the pre-processed portion in my models, but I'll need it later
# now define a preprocessing operation that returns the processed features
preprocessed_ds_train = ds_train.map(lambda x, y: (feature_space_1(x), y),
num_parallel_calls=tf.data.AUTOTUNE)
# run it so that we can use the pre-processed data
preprocessed_ds_train = preprocessed_ds_train.prefetch(tf.data.AUTOTUNE)
# do the same for the test set
preprocessed_ds_test = ds_test.map(lambda x, y: (feature_space_1(x), y), num_parallel_calls=tf.data.AUTOTUNE)
preprocessed_ds_test = preprocessed_ds_test.prefetch(tf.data.AUTOTUNE)
# Source: Modified from in-class lecture to match my dataset
dict_inputs = feature_space_1.get_inputs() # need to use unprocessed features here, to gain access to each output
# we need to create separate lists for each branch
crossed_outputs = []
# for each crossed variable, make an embedding
for col in feature_space_1.crossers.keys():
    x = setup_embedding_from_crossing(feature_space_1, col)
    # save these outputs in a list to concatenate later
    crossed_outputs.append(x)
# now concatenate the outputs and add a fully connected layer
wide_branch = Concatenate(name='wide_concat')(crossed_outputs)
# reset this input branch
all_deep_branch_outputs = []
# for each numeric variable, just add it in after embedding
for idx, col in enumerate(numeric_headers):
    x = feature_space_1.preprocessors[col].output
    x = tf.cast(x, float)  # cast integers to float here
    all_deep_branch_outputs.append(x)
# for each categorical variable
for col in categorical_headers:
    # get the output tensor from the embedding layer
    x = setup_embedding_from_categorical(feature_space_1, col)
    # save these outputs in a list to concatenate later
    all_deep_branch_outputs.append(x)
# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34,activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17,activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=10,activation='relu', name='deep3')(deep_branch)
# merge the deep and wide branch
final_branch = Concatenate(name='concat_deep_wide')([deep_branch, wide_branch])
final_branch = Dense(units=1,activation='sigmoid',
name='combined')(final_branch)
training_model_1 = keras.Model(inputs=dict_inputs, outputs=final_branch)
# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_1.compile(
optimizer="adam", loss="binary_crossentropy", metrics=['acc',f1_m,precision_m, recall_m]
)
training_model_1.summary()
plot_model(
training_model_1, to_file='model.png', show_shapes=True, show_layer_names=True,
rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_35"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
type_of_breast_surgery (InputL [(None, 1)] 0 []
ayer)
cancer_type_detailed (InputLay [(None, 1)] 0 []
er)
cellularity (InputLayer) [(None, 1)] 0 []
pam50_plus_claudin-low_subtype [(None, 1)] 0 []
(InputLayer)
er_status_measured_by_ihc (Inp [(None, 1)] 0 []
utLayer)
er_status (InputLayer) [(None, 1)] 0 []
her2_status_measured_by_snp6 ( [(None, 1)] 0 []
InputLayer)
her2_status (InputLayer) [(None, 1)] 0 []
tumor_other_histologic_subtype [(None, 1)] 0 []
(InputLayer)
inferred_menopausal_state (Inp [(None, 1)] 0 []
utLayer)
integrative_cluster (InputLaye [(None, 1)] 0 []
r)
primary_tumor_laterality (Inpu [(None, 1)] 0 []
tLayer)
oncotree_code (InputLayer) [(None, 1)] 0 []
pr_status (InputLayer) [(None, 1)] 0 []
3-gene_classifier_subtype (Inp [(None, 1)] 0 []
utLayer)
age_at_diagnosis (InputLayer) [(None, 1)] 0 []
neoplasm_histologic_grade (Inp [(None, 1)] 0 []
utLayer)
lymph_nodes_examined_positive [(None, 1)] 0 []
(InputLayer)
mutation_count (InputLayer) [(None, 1)] 0 []
nottingham_prognostic_index (I [(None, 1)] 0 []
nputLayer)
overall_survival_months (Input [(None, 1)] 0 []
Layer)
tumor_size (InputLayer) [(None, 1)] 0 []
tumor_stage (InputLayer) [(None, 1)] 0 []
chemotherapy (InputLayer) [(None, 1)] 0 []
radio_therapy (InputLayer) [(None, 1)] 0 []
hormone_therapy (InputLayer) [(None, 1)] 0 []
string_categorical_424_preproc (None, 1) 0 ['type_of_breast_surgery[0][0]']
essor (StringLookup)
string_categorical_425_preproc (None, 1) 0 ['cancer_type_detailed[0][0]']
essor (StringLookup)
string_categorical_426_preproc (None, 1) 0 ['cellularity[0][0]']
essor (StringLookup)
string_categorical_427_preproc (None, 1) 0 ['pam50_plus_claudin-low_subtype[
essor (StringLookup) 0][0]']
string_categorical_428_preproc (None, 1) 0 ['er_status_measured_by_ihc[0][0]
essor (StringLookup) ']
string_categorical_429_preproc (None, 1) 0 ['er_status[0][0]']
essor (StringLookup)
string_categorical_430_preproc (None, 1) 0 ['her2_status_measured_by_snp6[0]
essor (StringLookup) [0]']
string_categorical_431_preproc (None, 1) 0 ['her2_status[0][0]']
essor (StringLookup)
string_categorical_432_preproc (None, 1) 0 ['tumor_other_histologic_subtype[
essor (StringLookup) 0][0]']
string_categorical_433_preproc (None, 1) 0 ['inferred_menopausal_state[0][0]
essor (StringLookup) ']
string_categorical_434_preproc (None, 1) 0 ['integrative_cluster[0][0]']
essor (StringLookup)
string_categorical_435_preproc (None, 1) 0 ['primary_tumor_laterality[0][0]'
essor (StringLookup) ]
string_categorical_436_preproc (None, 1) 0 ['oncotree_code[0][0]']
essor (StringLookup)
string_categorical_437_preproc (None, 1) 0 ['pr_status[0][0]']
essor (StringLookup)
string_categorical_438_preproc (None, 1) 0 ['3-gene_classifier_subtype[0][0]
essor (StringLookup) ']
float_normalized_306_preproces (None, 1) 3 ['age_at_diagnosis[0][0]']
sor (Normalization)
float_normalized_307_preproces (None, 1) 3 ['neoplasm_histologic_grade[0][0]
sor (Normalization) ']
float_normalized_308_preproces (None, 1) 3 ['lymph_nodes_examined_positive[0
sor (Normalization) ][0]']
float_normalized_309_preproces (None, 1) 3 ['mutation_count[0][0]']
sor (Normalization)
float_normalized_310_preproces (None, 1) 3 ['nottingham_prognostic_index[0][
sor (Normalization) 0]']
float_normalized_311_preproces (None, 1) 3 ['overall_survival_months[0][0]']
sor (Normalization)
float_normalized_312_preproces (None, 1) 3 ['tumor_size[0][0]']
sor (Normalization)
float_normalized_313_preproces (None, 1) 3 ['tumor_stage[0][0]']
sor (Normalization)
float_normalized_303_preproces (None, 1) 3 ['chemotherapy[0][0]']
sor (Normalization)
float_normalized_305_preproces (None, 1) 3 ['radio_therapy[0][0]']
sor (Normalization)
float_normalized_304_preproces (None, 1) 3 ['hormone_therapy[0][0]']
sor (Normalization)
type_of_breast_surgery_embed ( (None, 1, 1) 2 ['string_categorical_424_preproce
Embedding) ssor[0][0]']
cancer_type_detailed_embed (Em (None, 1, 2) 10 ['string_categorical_425_preproce
bedding) ssor[0][0]']
cellularity_embed (Embedding) (None, 1, 1) 3 ['string_categorical_426_preproce
ssor[0][0]']
pam50_plus_claudin-low_subtype (None, 1, 2) 14 ['string_categorical_427_preproce
_embed (Embedding) ssor[0][0]']
er_status_measured_by_ihc_embe (None, 1, 1) 2 ['string_categorical_428_preproce
d (Embedding) ssor[0][0]']
er_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_429_preproce
ssor[0][0]']
her2_status_measured_by_snp6_e (None, 1, 2) 8 ['string_categorical_430_preproce
mbed (Embedding) ssor[0][0]']
her2_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_431_preproce
ssor[0][0]']
tumor_other_histologic_subtype (None, 1, 2) 14 ['string_categorical_432_preproce
_embed (Embedding) ssor[0][0]']
inferred_menopausal_state_embe (None, 1, 1) 2 ['string_categorical_433_preproce
d (Embedding) ssor[0][0]']
integrative_cluster_embed (Emb (None, 1, 3) 33 ['string_categorical_434_preproce
edding) ssor[0][0]']
primary_tumor_laterality_embed (None, 1, 1) 2 ['string_categorical_435_preproce
(Embedding) ssor[0][0]']
oncotree_code_embed (Embedding (None, 1, 2) 10 ['string_categorical_436_preproce
) ssor[0][0]']
pr_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_437_preproce
ssor[0][0]']
3-gene_classifier_subtype_embe (None, 1, 2) 8 ['string_categorical_438_preproce
d (Embedding) ssor[0][0]']
tf.cast_385 (TFOpLambda) (None, 1) 0 ['float_normalized_306_preprocess
or[0][0]']
tf.cast_386 (TFOpLambda) (None, 1) 0 ['float_normalized_307_preprocess
or[0][0]']
tf.cast_387 (TFOpLambda) (None, 1) 0 ['float_normalized_308_preprocess
or[0][0]']
tf.cast_388 (TFOpLambda) (None, 1) 0 ['float_normalized_309_preprocess
or[0][0]']
tf.cast_389 (TFOpLambda) (None, 1) 0 ['float_normalized_310_preprocess
or[0][0]']
tf.cast_390 (TFOpLambda) (None, 1) 0 ['float_normalized_311_preprocess
or[0][0]']
tf.cast_391 (TFOpLambda) (None, 1) 0 ['float_normalized_312_preprocess
or[0][0]']
tf.cast_392 (TFOpLambda) (None, 1) 0 ['float_normalized_313_preprocess
or[0][0]']
tf.cast_393 (TFOpLambda) (None, 1) 0 ['float_normalized_303_preprocess
or[0][0]']
tf.cast_394 (TFOpLambda) (None, 1) 0 ['float_normalized_305_preprocess
or[0][0]']
tf.cast_395 (TFOpLambda) (None, 1) 0 ['float_normalized_304_preprocess
or[0][0]']
flatten_620 (Flatten) (None, 1) 0 ['type_of_breast_surgery_embed[0]
[0]']
flatten_621 (Flatten) (None, 2) 0 ['cancer_type_detailed_embed[0][0
]']
flatten_622 (Flatten) (None, 1) 0 ['cellularity_embed[0][0]']
flatten_623 (Flatten) (None, 2) 0 ['pam50_plus_claudin-low_subtype_
embed[0][0]']
flatten_624 (Flatten) (None, 1) 0 ['er_status_measured_by_ihc_embed
[0][0]']
flatten_625 (Flatten) (None, 1) 0 ['er_status_embed[0][0]']
flatten_626 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_em
bed[0][0]']
flatten_627 (Flatten) (None, 1) 0 ['her2_status_embed[0][0]']
flatten_628 (Flatten) (None, 2) 0 ['tumor_other_histologic_subtype_
embed[0][0]']
flatten_629 (Flatten) (None, 1) 0 ['inferred_menopausal_state_embed
[0][0]']
flatten_630 (Flatten) (None, 3) 0 ['integrative_cluster_embed[0][0]
']
flatten_631 (Flatten) (None, 1) 0 ['primary_tumor_laterality_embed[
0][0]']
flatten_632 (Flatten) (None, 2) 0 ['oncotree_code_embed[0][0]']
flatten_633 (Flatten) (None, 1) 0 ['pr_status_embed[0][0]']
flatten_634 (Flatten) (None, 2) 0 ['3-gene_classifier_subtype_embed
[0][0]']
embed_concat (Concatenate) (None, 34) 0 ['tf.cast_385[0][0]',
'tf.cast_386[0][0]',
'tf.cast_387[0][0]',
'tf.cast_388[0][0]',
'tf.cast_389[0][0]',
'tf.cast_390[0][0]',
'tf.cast_391[0][0]',
'tf.cast_392[0][0]',
'tf.cast_393[0][0]',
'tf.cast_394[0][0]',
'tf.cast_395[0][0]',
'flatten_620[0][0]',
'flatten_621[0][0]',
'flatten_622[0][0]',
'flatten_623[0][0]',
'flatten_624[0][0]',
'flatten_625[0][0]',
'flatten_626[0][0]',
'flatten_627[0][0]',
'flatten_628[0][0]',
'flatten_629[0][0]',
'flatten_630[0][0]',
'flatten_631[0][0]',
'flatten_632[0][0]',
'flatten_633[0][0]',
'flatten_634[0][0]']
her2_status_measured_by_snp6_X (None, 1) 0 ['string_categorical_430_preproce
_her2_status (HashedCrossing) ssor[0][0]',
'string_categorical_431_preproce
ssor[0][0]']
3-gene_classifier_subtype_X_in (None, 1) 0 ['string_categorical_438_preproce
tegrative_cluster (HashedCross ssor[0][0]',
ing) 'string_categorical_434_preproce
ssor[0][0]']
er_status_X_er_status_measured (None, 1) 0 ['string_categorical_429_preproce
_by_ihc (HashedCrossing) ssor[0][0]',
'string_categorical_428_preproce
ssor[0][0]']
deep1 (Dense) (None, 34) 1190 ['embed_concat[0][0]']
her2_status_measured_by_snp6_X (None, 1, 2) 16 ['her2_status_measured_by_snp6_X_
_her2_status_embed (Embedding) her2_status[0][0]']
3-gene_classifier_subtype_X_in (None, 1, 6) 264 ['3-gene_classifier_subtype_X_int
tegrative_cluster_embed (Embed egrative_cluster[0][0]']
ding)
er_status_X_er_status_measured (None, 1, 2) 8 ['er_status_X_er_status_measured_
_by_ihc_embed (Embedding) by_ihc[0][0]']
deep2 (Dense) (None, 17) 595 ['deep1[0][0]']
flatten_617 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_X_
her2_status_embed[0][0]']
flatten_618 (Flatten) (None, 6) 0 ['3-gene_classifier_subtype_X_int
egrative_cluster_embed[0][0]']
flatten_619 (Flatten) (None, 2) 0 ['er_status_X_er_status_measured_
by_ihc_embed[0][0]']
deep3 (Dense) (None, 10) 180 ['deep2[0][0]']
wide_concat (Concatenate) (None, 10) 0 ['flatten_617[0][0]',
'flatten_618[0][0]',
'flatten_619[0][0]']
concat_deep_wide (Concatenate) (None, 20) 0 ['deep3[0][0]',
'wide_concat[0][0]']
combined (Dense) (None, 1) 21 ['concat_deep_wide[0][0]']
==================================================================================================
Total params: 2,421
Trainable params: 2,388
Non-trainable params: 33
__________________________________________________________________________________________________
# Train the model
history_1 = training_model_1.fit(
ds_train, epochs=25, validation_data=ds_test, verbose=2
)
Epoch 1/25
19/19 - 4s - loss: 0.6937 - acc: 0.5211 - f1_m: 0.1378 - precision_m: 0.5921 - recall_m: 0.0799 - val_loss: 0.6872 - val_acc: 0.5247 - val_f1_m: 0.2440 - val_precision_m: 0.6726 - val_recall_m: 0.1592 - 4s/epoch - 236ms/step
Epoch 2/25
19/19 - 0s - loss: 0.6610 - acc: 0.6563 - f1_m: 0.4916 - precision_m: 0.8188 - recall_m: 0.3629 - val_loss: 0.6650 - val_acc: 0.6481 - val_f1_m: 0.5436 - val_precision_m: 0.7738 - val_recall_m: 0.4575 - 73ms/epoch - 4ms/step
Epoch 3/25
19/19 - 0s - loss: 0.6370 - acc: 0.7072 - f1_m: 0.6399 - precision_m: 0.7814 - recall_m: 0.5686 - val_loss: 0.6426 - val_acc: 0.6914 - val_f1_m: 0.6419 - val_precision_m: 0.7453 - val_recall_m: 0.5908 - 71ms/epoch - 4ms/step
Epoch 4/25
19/19 - 0s - loss: 0.6084 - acc: 0.7233 - f1_m: 0.6792 - precision_m: 0.7612 - recall_m: 0.6501 - val_loss: 0.6136 - val_acc: 0.7099 - val_f1_m: 0.6848 - val_precision_m: 0.7398 - val_recall_m: 0.6750 - 74ms/epoch - 4ms/step
Epoch 5/25
19/19 - 0s - loss: 0.5751 - acc: 0.7531 - f1_m: 0.7241 - precision_m: 0.7796 - recall_m: 0.7207 - val_loss: 0.5801 - val_acc: 0.7160 - val_f1_m: 0.6959 - val_precision_m: 0.7464 - val_recall_m: 0.6958 - 72ms/epoch - 4ms/step
Epoch 6/25
19/19 - 0s - loss: 0.5426 - acc: 0.7568 - f1_m: 0.7268 - precision_m: 0.7738 - recall_m: 0.7275 - val_loss: 0.5459 - val_acc: 0.7469 - val_f1_m: 0.7314 - val_precision_m: 0.7767 - val_recall_m: 0.7258 - 76ms/epoch - 4ms/step
Epoch 7/25
19/19 - 0s - loss: 0.5143 - acc: 0.7655 - f1_m: 0.7351 - precision_m: 0.7764 - recall_m: 0.7379 - val_loss: 0.5161 - val_acc: 0.7531 - val_f1_m: 0.7438 - val_precision_m: 0.7765 - val_recall_m: 0.7567 - 109ms/epoch - 6ms/step
Epoch 8/25
19/19 - 0s - loss: 0.4910 - acc: 0.7742 - f1_m: 0.7469 - precision_m: 0.7830 - recall_m: 0.7535 - val_loss: 0.4912 - val_acc: 0.7716 - val_f1_m: 0.7688 - val_precision_m: 0.7714 - val_recall_m: 0.7992 - 72ms/epoch - 4ms/step
Epoch 9/25
19/19 - 0s - loss: 0.4716 - acc: 0.7816 - f1_m: 0.7618 - precision_m: 0.7834 - recall_m: 0.7772 - val_loss: 0.4714 - val_acc: 0.7963 - val_f1_m: 0.7937 - val_precision_m: 0.8052 - val_recall_m: 0.8200 - 90ms/epoch - 5ms/step
Epoch 10/25
19/19 - 0s - loss: 0.4562 - acc: 0.7965 - f1_m: 0.7793 - precision_m: 0.7951 - recall_m: 0.7942 - val_loss: 0.4565 - val_acc: 0.8148 - val_f1_m: 0.8104 - val_precision_m: 0.8170 - val_recall_m: 0.8400 - 77ms/epoch - 4ms/step
Epoch 11/25
19/19 - 0s - loss: 0.4438 - acc: 0.8040 - f1_m: 0.7893 - precision_m: 0.7950 - recall_m: 0.8138 - val_loss: 0.4461 - val_acc: 0.8148 - val_f1_m: 0.8104 - val_precision_m: 0.8170 - val_recall_m: 0.8400 - 69ms/epoch - 4ms/step
Epoch 12/25
19/19 - 0s - loss: 0.4342 - acc: 0.8040 - f1_m: 0.7893 - precision_m: 0.7951 - recall_m: 0.8140 - val_loss: 0.4388 - val_acc: 0.8148 - val_f1_m: 0.8104 - val_precision_m: 0.8222 - val_recall_m: 0.8400 - 71ms/epoch - 4ms/step
Epoch 13/25
19/19 - 0s - loss: 0.4268 - acc: 0.8065 - f1_m: 0.7906 - precision_m: 0.7976 - recall_m: 0.8120 - val_loss: 0.4338 - val_acc: 0.8148 - val_f1_m: 0.8104 - val_precision_m: 0.8222 - val_recall_m: 0.8400 - 74ms/epoch - 4ms/step
Epoch 14/25
19/19 - 0s - loss: 0.4203 - acc: 0.8077 - f1_m: 0.7915 - precision_m: 0.7980 - recall_m: 0.8120 - val_loss: 0.4302 - val_acc: 0.8210 - val_f1_m: 0.8157 - val_precision_m: 0.8287 - val_recall_m: 0.8400 - 77ms/epoch - 4ms/step
Epoch 15/25
19/19 - 0s - loss: 0.4142 - acc: 0.8089 - f1_m: 0.7926 - precision_m: 0.7997 - recall_m: 0.8115 - val_loss: 0.4283 - val_acc: 0.8272 - val_f1_m: 0.8214 - val_precision_m: 0.8359 - val_recall_m: 0.8400 - 71ms/epoch - 4ms/step
Epoch 16/25
19/19 - 0s - loss: 0.4089 - acc: 0.8139 - f1_m: 0.7970 - precision_m: 0.8023 - recall_m: 0.8157 - val_loss: 0.4272 - val_acc: 0.8272 - val_f1_m: 0.8214 - val_precision_m: 0.8359 - val_recall_m: 0.8400 - 73ms/epoch - 4ms/step
Epoch 17/25
19/19 - 0s - loss: 0.4045 - acc: 0.8164 - f1_m: 0.8005 - precision_m: 0.8061 - recall_m: 0.8193 - val_loss: 0.4266 - val_acc: 0.8272 - val_f1_m: 0.8214 - val_precision_m: 0.8359 - val_recall_m: 0.8400 - 72ms/epoch - 4ms/step
Epoch 18/25
19/19 - 0s - loss: 0.4001 - acc: 0.8176 - f1_m: 0.8030 - precision_m: 0.8072 - recall_m: 0.8235 - val_loss: 0.4264 - val_acc: 0.8210 - val_f1_m: 0.8146 - val_precision_m: 0.8345 - val_recall_m: 0.8300 - 75ms/epoch - 4ms/step
Epoch 19/25
19/19 - 0s - loss: 0.3957 - acc: 0.8213 - f1_m: 0.8063 - precision_m: 0.8115 - recall_m: 0.8235 - val_loss: 0.4265 - val_acc: 0.8210 - val_f1_m: 0.8146 - val_precision_m: 0.8345 - val_recall_m: 0.8300 - 72ms/epoch - 4ms/step
Epoch 20/25
19/19 - 0s - loss: 0.3919 - acc: 0.8213 - f1_m: 0.8063 - precision_m: 0.8115 - recall_m: 0.8235 - val_loss: 0.4266 - val_acc: 0.8210 - val_f1_m: 0.8146 - val_precision_m: 0.8345 - val_recall_m: 0.8300 - 73ms/epoch - 4ms/step
Epoch 21/25
19/19 - 0s - loss: 0.3878 - acc: 0.8251 - f1_m: 0.8097 - precision_m: 0.8150 - recall_m: 0.8261 - val_loss: 0.4267 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 71ms/epoch - 4ms/step
Epoch 22/25
19/19 - 0s - loss: 0.3840 - acc: 0.8288 - f1_m: 0.8131 - precision_m: 0.8187 - recall_m: 0.8289 - val_loss: 0.4269 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 78ms/epoch - 4ms/step
Epoch 23/25
19/19 - 0s - loss: 0.3805 - acc: 0.8300 - f1_m: 0.8161 - precision_m: 0.8186 - recall_m: 0.8365 - val_loss: 0.4273 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 71ms/epoch - 4ms/step
Epoch 24/25
19/19 - 0s - loss: 0.3766 - acc: 0.8325 - f1_m: 0.8178 - precision_m: 0.8224 - recall_m: 0.8340 - val_loss: 0.4277 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 87ms/epoch - 5ms/step
Epoch 25/25
19/19 - 0s - loss: 0.3734 - acc: 0.8325 - f1_m: 0.8180 - precision_m: 0.8233 - recall_m: 0.8353 - val_loss: 0.4279 - val_acc: 0.8148 - val_f1_m: 0.8092 - val_precision_m: 0.8228 - val_recall_m: 0.8300 - 78ms/epoch - 4ms/step
# Print plots of metrics
from matplotlib import pyplot as plt
%matplotlib inline
plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_1.history['f1_m'])
plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_1.history['val_f1_m'])
plt.title('Validation')
plt.subplot(2,2,3)
plt.plot(history_1.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')
plt.subplot(2,2,4)
plt.plot(history_1.history['val_loss'])
plt.xlabel('epochs')
Text(0.5, 0, 'epochs')
This model converges at around 20-25 epochs, and the F1 score is solid on both the training and validation data. Ideally I'd still like to see the training loss drop further. Next I'll check the confusion matrix to see my ratio of true positives, true negatives, false positives, and false negatives.
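Rather than eyeballing the curves, the stopping point can be read off the history directly; a minimal sketch with a hypothetical validation-loss curve (Keras's `EarlyStopping` callback, with `patience` and `restore_best_weights=True`, automates the same idea during training):

```python
import numpy as np

# Hypothetical val_loss curve (not the actual run's numbers): it falls,
# bottoms out, then creeps back up as the model begins to overfit.
val_loss = [0.687, 0.545, 0.472, 0.434, 0.430, 0.427, 0.426, 0.427, 0.428]

# The epoch to keep is the one with the minimum validation loss.
best_epoch = int(np.argmin(val_loss)) + 1  # history lists are 0-indexed
print(best_epoch)  # 7
```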
# Visualize some metrics associated with this model
# Source: Modified from in-class lecture
import numpy as np
import seaborn as sns
import tensorflow as tf
from sklearn import metrics as mt
# pull the labels back out of the batched test dataset
y_test = tf.concat([y for x, y in ds_test], axis=0)
y_test = y_test.numpy()
# now let's see how well the model performed
yhat_proba_1 = training_model_1.predict(ds_test) # sigmoid output probabilities
# squeeze away any extra dimensions, then round to get the binary class
yhat_1 = np.round(yhat_proba_1.squeeze())
conf_mat_1 = mt.confusion_matrix(y_test, yhat_1)
print(conf_mat_1)
print(mt.classification_report(y_test, yhat_1))
# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309. VitalBook file.
# Create pandas dataframe
conf_df_1 = pd.DataFrame(conf_mat_1, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())
# Create heatmap
sns.heatmap(conf_df_1, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 1s 3ms/step
[[66 14]
[16 66]]
precision recall f1-score support
0 0.80 0.82 0.81 80
1 0.82 0.80 0.81 82
accuracy 0.81 162
macro avg 0.81 0.81 0.81 162
weighted avg 0.82 0.81 0.81 162
Not a bad result. My true positives and true negatives are high relative to my false positives and false negatives. However, I'd really like to see fewer false positives: as I stated above, they can be bad for patient well-being because of the decisions they may drive. We'll see if we can improve on this in subsequent models.
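One lever for trading false positives against false negatives, without retraining, is the decision threshold: `np.round` above implicitly cuts at 0.5, but raising the cutoff makes the positive call stricter. A toy sketch with made-up probabilities (not the model's actual outputs):

```python
import numpy as np

# Made-up sigmoid outputs and true labels, just to show the mechanics
proba  = np.array([0.30, 0.55, 0.58, 0.62, 0.70, 0.85])
y_true = np.array([0,    0,    1,    1,    1,    1])

def fp_fn(threshold):
    """Count false positives and false negatives at a given threshold."""
    pred = (proba >= threshold).astype(int)
    fp = int(np.sum((pred == 1) & (y_true == 0)))
    fn = int(np.sum((pred == 0) & (y_true == 1)))
    return fp, fn

print(fp_fn(0.5))  # (1, 0): one false positive at the default threshold
print(fp_fn(0.6))  # (0, 1): stricter cutoff removes it, at the cost of a miss
```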
Model 2 of 3¶
This model changes some of the cross features to see whether that affects the results. In model 2 I'll try one of the other feature-space arrangements discussed above, using the following crosses:
- her2_status_measured_by_snp6 and her2_status (old)
- 3-gene_classifier_subtype, integrative_cluster, and pam50_plus_claudin-low_subtype (new)
- er_status and er_status_measured_by_ihc (old)
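Under the hood, a feature cross hashes each combination of category values into a fixed number of bins (the `crossing_dim`) before embedding. A pure-Python sketch of the idea; Keras's `HashedCrossing` layer uses its own hash function, so the actual bin ids will differ:

```python
import hashlib

def hashed_cross(values, num_bins):
    """Map a tuple of category strings to one of num_bins bins."""
    key = "_X_".join(values).encode()
    return int(hashlib.md5(key).hexdigest(), 16) % num_bins

# e.g. her2_status_measured_by_snp6 x her2_status with crossing_dim = 4 * 2
bin_id = hashed_cross(("GAIN", "Positive"), num_bins=8)
print(0 <= bin_id < 8)  # always lands in a valid bin
```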
# Source: Modified from in-class lecture to match my dataset
from tensorflow.keras.utils import FeatureSpace
feature_space_2 = FeatureSpace(
features={
# Categorical feature encoded as string
"type_of_breast_surgery": FeatureSpace.string_categorical(num_oov_indices=0),
"cancer_type_detailed": FeatureSpace.string_categorical(num_oov_indices=0),
"cellularity": FeatureSpace.string_categorical(num_oov_indices=0),
"pam50_plus_claudin-low_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
"er_status_measured_by_ihc": FeatureSpace.string_categorical(num_oov_indices=0),
"er_status": FeatureSpace.string_categorical(num_oov_indices=0),
"her2_status_measured_by_snp6": FeatureSpace.string_categorical(num_oov_indices=0),
"her2_status": FeatureSpace.string_categorical(num_oov_indices=0),
"tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
"inferred_menopausal_state": FeatureSpace.string_categorical(num_oov_indices=0),
"integrative_cluster": FeatureSpace.string_categorical(num_oov_indices=0),
"primary_tumor_laterality": FeatureSpace.string_categorical(num_oov_indices=0),
"oncotree_code": FeatureSpace.string_categorical(num_oov_indices=0),
"pr_status": FeatureSpace.string_categorical(num_oov_indices=0),
"3-gene_classifier_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
# Numerical features to normalize (normalization will be learned)
# learns the mean, variance, and if to invert
"chemotherapy": FeatureSpace.float_normalized(),
"hormone_therapy": FeatureSpace.float_normalized(),
"radio_therapy": FeatureSpace.float_normalized(),
"age_at_diagnosis": FeatureSpace.float_normalized(),
"neoplasm_histologic_grade": FeatureSpace.float_normalized(),
"lymph_nodes_examined_positive": FeatureSpace.float_normalized(),
"mutation_count": FeatureSpace.float_normalized(),
"nottingham_prognostic_index": FeatureSpace.float_normalized(),
"overall_survival_months": FeatureSpace.float_normalized(),
"tumor_size": FeatureSpace.float_normalized(),
"tumor_stage": FeatureSpace.float_normalized(),
},
# Specify feature cross with a custom crossing dim
crosses=[
FeatureSpace.cross(
feature_names=('her2_status_measured_by_snp6','her2_status'),
crossing_dim=4*2),
FeatureSpace.cross(
feature_names=('3-gene_classifier_subtype', 'integrative_cluster', 'pam50_plus_claudin-low_subtype'),
crossing_dim=4*11*7),
FeatureSpace.cross(
feature_names=('er_status', 'er_status_measured_by_ihc'),
crossing_dim=2*2),
],
output_mode="concat",
)
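The `crossing_dim` values above (4\*2, 4\*11\*7, 2\*2) come from the cardinality of each crossed column, which can be checked with `nunique()`; a sketch against a hypothetical mini-frame rather than the full bc_df:

```python
import pandas as pd

# Tiny stand-in for bc_df with the two HER2 columns
df = pd.DataFrame({
    "her2_status_measured_by_snp6": ["GAIN", "LOSS", "NEUTRAL", "UNDEF"],
    "her2_status": ["Positive", "Negative", "Negative", "Negative"],
})

# Full cross-product size = product of per-column cardinalities
crossing_dim = (df["her2_status_measured_by_snp6"].nunique()
                * df["her2_status"].nunique())
print(crossing_dim)  # 4 * 2 = 8
```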
# now that we have specified the preprocessing, let's run it on the data
# create a version of the dataset that can be iterated without labels
train_ds_with_no_labels = ds_train.map(lambda x, _: x)
feature_space_2.adapt(train_ds_with_no_labels) # initialize the feature map to this data
# the adapt function allows the model to learn one-hot encoding sizes
# now define a preprocessing operation that returns the processed features
# preprocessed_ds_train = ds_train.map(lambda x, y: (feature_space_2(x), y),
# num_parallel_calls=tf.data.AUTOTUNE)
# # run it so that we can use the pre-processed data
# preprocessed_ds_train = preprocessed_ds_train.prefetch(tf.data.AUTOTUNE)
# # do the same for the test set
# preprocessed_ds_test = ds_test.map(lambda x, y: (feature_space_2(x), y), num_parallel_calls=tf.data.AUTOTUNE)
# preprocessed_ds_test = preprocessed_ds_test.prefetch(tf.data.AUTOTUNE)
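For the `float_normalized` features, what `adapt` learns boils down to each column's mean and variance, which the Normalization layer then uses to standardize inputs. A numpy sketch of that standardization, with made-up values:

```python
import numpy as np

# Made-up feature values, e.g. a tumor_size-like column
x = np.array([20.0, 30.0, 40.0, 50.0])

# adapt() learns the mean and variance; the layer then standardizes with them
mean, var = x.mean(), x.var()
z = (x - mean) / np.sqrt(var)
print(round(float(z[0]), 3))  # -1.342
```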
# from keras.metrics import Precision, Recall
dict_inputs = feature_space_2.get_inputs() # need to use unprocessed features here, to gain access to each output
# we need to create separate lists for each branch
crossed_outputs = []
# for each crossed variable, make an embedding
for col in feature_space_2.crossers.keys():
x = setup_embedding_from_crossing(feature_space_2, col)
# save these outputs in list to concatenate later
crossed_outputs.append(x)
# now concatenate the outputs and add a fully connected layer
wide_branch = Concatenate(name='wide_concat')(crossed_outputs)
# reset this input branch
all_deep_branch_outputs = []
# for each numeric variable, just add it in after embedding
for idx,col in enumerate(numeric_headers):
x = feature_space_2.preprocessors[col].output
x = tf.cast(x,float) # cast an integer as a float here
all_deep_branch_outputs.append(x)
# for each categorical variable
for col in categorical_headers:
# get the output tensor from the embedding layer
x = setup_embedding_from_categorical(feature_space_2, col)
# save these outputs in list to concatenate later
all_deep_branch_outputs.append(x)
# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34,activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17,activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=10,activation='relu', name='deep3')(deep_branch)
# merge the deep and wide branch
final_branch = Concatenate(name='concat_deep_wide')([deep_branch, wide_branch])
final_branch = Dense(units=1,activation='sigmoid',
name='combined')(final_branch)
training_model_2 = keras.Model(inputs=dict_inputs, outputs=final_branch)
# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_2.compile(
optimizer="adam", loss="binary_crossentropy", metrics=['acc',f1_m,precision_m, recall_m]
)
training_model_2.summary()
plot_model(
training_model_2, to_file='model.png', show_shapes=True, show_layer_names=True,
rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_36"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
type_of_breast_surgery (InputL [(None, 1)] 0 []
ayer)
cancer_type_detailed (InputLay [(None, 1)] 0 []
er)
cellularity (InputLayer) [(None, 1)] 0 []
pam50_plus_claudin-low_subtype [(None, 1)] 0 []
(InputLayer)
er_status_measured_by_ihc (Inp [(None, 1)] 0 []
utLayer)
er_status (InputLayer) [(None, 1)] 0 []
her2_status_measured_by_snp6 ( [(None, 1)] 0 []
InputLayer)
her2_status (InputLayer) [(None, 1)] 0 []
tumor_other_histologic_subtype [(None, 1)] 0 []
(InputLayer)
inferred_menopausal_state (Inp [(None, 1)] 0 []
utLayer)
integrative_cluster (InputLaye [(None, 1)] 0 []
r)
primary_tumor_laterality (Inpu [(None, 1)] 0 []
tLayer)
oncotree_code (InputLayer) [(None, 1)] 0 []
pr_status (InputLayer) [(None, 1)] 0 []
3-gene_classifier_subtype (Inp [(None, 1)] 0 []
utLayer)
age_at_diagnosis (InputLayer) [(None, 1)] 0 []
neoplasm_histologic_grade (Inp [(None, 1)] 0 []
utLayer)
lymph_nodes_examined_positive [(None, 1)] 0 []
(InputLayer)
mutation_count (InputLayer) [(None, 1)] 0 []
nottingham_prognostic_index (I [(None, 1)] 0 []
nputLayer)
overall_survival_months (Input [(None, 1)] 0 []
Layer)
tumor_size (InputLayer) [(None, 1)] 0 []
tumor_stage (InputLayer) [(None, 1)] 0 []
chemotherapy (InputLayer) [(None, 1)] 0 []
radio_therapy (InputLayer) [(None, 1)] 0 []
hormone_therapy (InputLayer) [(None, 1)] 0 []
string_categorical_439_preproc (None, 1) 0 ['type_of_breast_surgery[0][0]']
essor (StringLookup)
string_categorical_440_preproc (None, 1) 0 ['cancer_type_detailed[0][0]']
essor (StringLookup)
string_categorical_441_preproc (None, 1) 0 ['cellularity[0][0]']
essor (StringLookup)
string_categorical_442_preproc (None, 1) 0 ['pam50_plus_claudin-low_subtype[
essor (StringLookup) 0][0]']
string_categorical_443_preproc (None, 1) 0 ['er_status_measured_by_ihc[0][0]
essor (StringLookup) ']
string_categorical_444_preproc (None, 1) 0 ['er_status[0][0]']
essor (StringLookup)
string_categorical_445_preproc (None, 1) 0 ['her2_status_measured_by_snp6[0]
essor (StringLookup) [0]']
string_categorical_446_preproc (None, 1) 0 ['her2_status[0][0]']
essor (StringLookup)
string_categorical_447_preproc (None, 1) 0 ['tumor_other_histologic_subtype[
essor (StringLookup) 0][0]']
string_categorical_448_preproc (None, 1) 0 ['inferred_menopausal_state[0][0]
essor (StringLookup) ']
string_categorical_449_preproc (None, 1) 0 ['integrative_cluster[0][0]']
essor (StringLookup)
string_categorical_450_preproc (None, 1) 0 ['primary_tumor_laterality[0][0]'
essor (StringLookup) ]
string_categorical_451_preproc (None, 1) 0 ['oncotree_code[0][0]']
essor (StringLookup)
string_categorical_452_preproc (None, 1) 0 ['pr_status[0][0]']
essor (StringLookup)
string_categorical_453_preproc (None, 1) 0 ['3-gene_classifier_subtype[0][0]
essor (StringLookup) ']
float_normalized_317_preproces (None, 1) 3 ['age_at_diagnosis[0][0]']
sor (Normalization)
float_normalized_318_preproces (None, 1) 3 ['neoplasm_histologic_grade[0][0]
sor (Normalization) ']
float_normalized_319_preproces (None, 1) 3 ['lymph_nodes_examined_positive[0
sor (Normalization) ][0]']
float_normalized_320_preproces (None, 1) 3 ['mutation_count[0][0]']
sor (Normalization)
float_normalized_321_preproces (None, 1) 3 ['nottingham_prognostic_index[0][
sor (Normalization) 0]']
float_normalized_322_preproces (None, 1) 3 ['overall_survival_months[0][0]']
sor (Normalization)
float_normalized_323_preproces (None, 1) 3 ['tumor_size[0][0]']
sor (Normalization)
float_normalized_324_preproces (None, 1) 3 ['tumor_stage[0][0]']
sor (Normalization)
float_normalized_314_preproces (None, 1) 3 ['chemotherapy[0][0]']
sor (Normalization)
float_normalized_316_preproces (None, 1) 3 ['radio_therapy[0][0]']
sor (Normalization)
float_normalized_315_preproces (None, 1) 3 ['hormone_therapy[0][0]']
sor (Normalization)
type_of_breast_surgery_embed ( (None, 1, 1) 2 ['string_categorical_439_preproce
Embedding) ssor[0][0]']
cancer_type_detailed_embed (Em (None, 1, 2) 10 ['string_categorical_440_preproce
bedding) ssor[0][0]']
cellularity_embed (Embedding) (None, 1, 1) 3 ['string_categorical_441_preproce
ssor[0][0]']
pam50_plus_claudin-low_subtype (None, 1, 2) 14 ['string_categorical_442_preproce
_embed (Embedding) ssor[0][0]']
er_status_measured_by_ihc_embe (None, 1, 1) 2 ['string_categorical_443_preproce
d (Embedding) ssor[0][0]']
er_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_444_preproce
ssor[0][0]']
her2_status_measured_by_snp6_e (None, 1, 2) 8 ['string_categorical_445_preproce
mbed (Embedding) ssor[0][0]']
her2_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_446_preproce
ssor[0][0]']
tumor_other_histologic_subtype (None, 1, 2) 14 ['string_categorical_447_preproce
_embed (Embedding) ssor[0][0]']
inferred_menopausal_state_embe (None, 1, 1) 2 ['string_categorical_448_preproce
d (Embedding) ssor[0][0]']
integrative_cluster_embed (Emb (None, 1, 3) 33 ['string_categorical_449_preproce
edding) ssor[0][0]']
primary_tumor_laterality_embed (None, 1, 1) 2 ['string_categorical_450_preproce
(Embedding) ssor[0][0]']
oncotree_code_embed (Embedding (None, 1, 2) 10 ['string_categorical_451_preproce
) ssor[0][0]']
pr_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_452_preproce
ssor[0][0]']
3-gene_classifier_subtype_embe (None, 1, 2) 8 ['string_categorical_453_preproce
d (Embedding) ssor[0][0]']
tf.cast_396 (TFOpLambda) (None, 1) 0 ['float_normalized_317_preprocess
or[0][0]']
tf.cast_397 (TFOpLambda) (None, 1) 0 ['float_normalized_318_preprocess
or[0][0]']
tf.cast_398 (TFOpLambda) (None, 1) 0 ['float_normalized_319_preprocess
or[0][0]']
tf.cast_399 (TFOpLambda) (None, 1) 0 ['float_normalized_320_preprocess
or[0][0]']
tf.cast_400 (TFOpLambda) (None, 1) 0 ['float_normalized_321_preprocess
or[0][0]']
tf.cast_401 (TFOpLambda) (None, 1) 0 ['float_normalized_322_preprocess
or[0][0]']
tf.cast_402 (TFOpLambda) (None, 1) 0 ['float_normalized_323_preprocess
or[0][0]']
tf.cast_403 (TFOpLambda) (None, 1) 0 ['float_normalized_324_preprocess
or[0][0]']
tf.cast_404 (TFOpLambda) (None, 1) 0 ['float_normalized_314_preprocess
or[0][0]']
tf.cast_405 (TFOpLambda) (None, 1) 0 ['float_normalized_316_preprocess
or[0][0]']
tf.cast_406 (TFOpLambda) (None, 1) 0 ['float_normalized_315_preprocess
or[0][0]']
flatten_638 (Flatten) (None, 1) 0 ['type_of_breast_surgery_embed[0]
[0]']
flatten_639 (Flatten) (None, 2) 0 ['cancer_type_detailed_embed[0][0
]']
flatten_640 (Flatten) (None, 1) 0 ['cellularity_embed[0][0]']
flatten_641 (Flatten) (None, 2) 0 ['pam50_plus_claudin-low_subtype_
embed[0][0]']
flatten_642 (Flatten) (None, 1) 0 ['er_status_measured_by_ihc_embed
[0][0]']
flatten_643 (Flatten) (None, 1) 0 ['er_status_embed[0][0]']
flatten_644 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_em
bed[0][0]']
flatten_645 (Flatten) (None, 1) 0 ['her2_status_embed[0][0]']
flatten_646 (Flatten) (None, 2) 0 ['tumor_other_histologic_subtype_
embed[0][0]']
flatten_647 (Flatten) (None, 1) 0 ['inferred_menopausal_state_embed
[0][0]']
flatten_648 (Flatten) (None, 3) 0 ['integrative_cluster_embed[0][0]
']
flatten_649 (Flatten) (None, 1) 0 ['primary_tumor_laterality_embed[
0][0]']
flatten_650 (Flatten) (None, 2) 0 ['oncotree_code_embed[0][0]']
flatten_651 (Flatten) (None, 1) 0 ['pr_status_embed[0][0]']
flatten_652 (Flatten) (None, 2) 0 ['3-gene_classifier_subtype_embed
[0][0]']
embed_concat (Concatenate) (None, 34) 0 ['tf.cast_396[0][0]',
'tf.cast_397[0][0]',
'tf.cast_398[0][0]',
'tf.cast_399[0][0]',
'tf.cast_400[0][0]',
'tf.cast_401[0][0]',
'tf.cast_402[0][0]',
'tf.cast_403[0][0]',
'tf.cast_404[0][0]',
'tf.cast_405[0][0]',
'tf.cast_406[0][0]',
'flatten_638[0][0]',
'flatten_639[0][0]',
'flatten_640[0][0]',
'flatten_641[0][0]',
'flatten_642[0][0]',
'flatten_643[0][0]',
'flatten_644[0][0]',
'flatten_645[0][0]',
'flatten_646[0][0]',
'flatten_647[0][0]',
'flatten_648[0][0]',
'flatten_649[0][0]',
'flatten_650[0][0]',
'flatten_651[0][0]',
'flatten_652[0][0]']
her2_status_measured_by_snp6_X (None, 1) 0 ['string_categorical_445_preproce
_her2_status (HashedCrossing) ssor[0][0]',
'string_categorical_446_preproce
ssor[0][0]']
3-gene_classifier_subtype_X_in (None, 1) 0 ['string_categorical_453_preproce
tegrative_cluster_X_pam50_plus ssor[0][0]',
_claudin-low_subtype (HashedCr 'string_categorical_449_preproce
ossing) ssor[0][0]',
'string_categorical_442_preproce
ssor[0][0]']
er_status_X_er_status_measured (None, 1) 0 ['string_categorical_444_preproce
_by_ihc (HashedCrossing) ssor[0][0]',
'string_categorical_443_preproce
ssor[0][0]']
deep1 (Dense) (None, 34) 1190 ['embed_concat[0][0]']
her2_status_measured_by_snp6_X (None, 1, 2) 16 ['her2_status_measured_by_snp6_X_
_her2_status_embed (Embedding) her2_status[0][0]']
3-gene_classifier_subtype_X_in (None, 1, 17) 5236 ['3-gene_classifier_subtype_X_int
tegrative_cluster_X_pam50_plus egrative_cluster_X_pam50_plus_cla
_claudin-low_subtype_embed (Em udin-low_subtype[0][0]']
bedding)
er_status_X_er_status_measured (None, 1, 2) 8 ['er_status_X_er_status_measured_
_by_ihc_embed (Embedding) by_ihc[0][0]']
deep2 (Dense) (None, 17) 595 ['deep1[0][0]']
flatten_635 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_X_
her2_status_embed[0][0]']
flatten_636 (Flatten) (None, 17) 0 ['3-gene_classifier_subtype_X_int
egrative_cluster_X_pam50_plus_cla
udin-low_subtype_embed[0][0]']
flatten_637 (Flatten) (None, 2) 0 ['er_status_X_er_status_measured_
by_ihc_embed[0][0]']
deep3 (Dense) (None, 10) 180 ['deep2[0][0]']
wide_concat (Concatenate) (None, 21) 0 ['flatten_635[0][0]',
'flatten_636[0][0]',
'flatten_637[0][0]']
concat_deep_wide (Concatenate) (None, 31) 0 ['deep3[0][0]',
'wide_concat[0][0]']
combined (Dense) (None, 1) 32 ['concat_deep_wide[0][0]']
==================================================================================================
Total params: 7,404
Trainable params: 7,371
Non-trainable params: 33
__________________________________________________________________________________________________
# Train the model (the preprocessing layers are built into the model)
history_2 = training_model_2.fit(
ds_train, epochs=15, validation_data=ds_test, verbose=2
)
Epoch 1/15
19/19 - 4s - loss: 0.6784 - acc: 0.5670 - f1_m: 0.3342 - precision_m: 0.8136 - recall_m: 0.2542 - val_loss: 0.6557 - val_acc: 0.6667 - val_f1_m: 0.5209 - val_precision_m: 0.9167 - val_recall_m: 0.3925 - 4s/epoch - 231ms/step
Epoch 2/15
19/19 - 0s - loss: 0.6421 - acc: 0.7109 - f1_m: 0.5894 - precision_m: 0.8316 - recall_m: 0.4842 - val_loss: 0.6249 - val_acc: 0.7593 - val_f1_m: 0.7175 - val_precision_m: 0.8542 - val_recall_m: 0.6650 - 76ms/epoch - 4ms/step
Epoch 3/15
19/19 - 0s - loss: 0.6029 - acc: 0.7481 - f1_m: 0.7024 - precision_m: 0.8116 - recall_m: 0.6567 - val_loss: 0.5847 - val_acc: 0.7654 - val_f1_m: 0.7393 - val_precision_m: 0.8210 - val_recall_m: 0.7158 - 77ms/epoch - 4ms/step
Epoch 4/15
19/19 - 0s - loss: 0.5591 - acc: 0.7630 - f1_m: 0.7296 - precision_m: 0.7956 - recall_m: 0.7142 - val_loss: 0.5429 - val_acc: 0.7840 - val_f1_m: 0.7697 - val_precision_m: 0.8240 - val_recall_m: 0.7667 - 77ms/epoch - 4ms/step
Epoch 5/15
19/19 - 0s - loss: 0.5192 - acc: 0.7742 - f1_m: 0.7510 - precision_m: 0.7916 - recall_m: 0.7539 - val_loss: 0.5037 - val_acc: 0.8025 - val_f1_m: 0.7893 - val_precision_m: 0.8267 - val_recall_m: 0.7992 - 72ms/epoch - 4ms/step
Epoch 6/15
19/19 - 0s - loss: 0.4864 - acc: 0.7866 - f1_m: 0.7675 - precision_m: 0.7946 - recall_m: 0.7805 - val_loss: 0.4720 - val_acc: 0.8148 - val_f1_m: 0.8076 - val_precision_m: 0.8264 - val_recall_m: 0.8300 - 70ms/epoch - 4ms/step
Epoch 7/15
19/19 - 0s - loss: 0.4617 - acc: 0.7953 - f1_m: 0.7789 - precision_m: 0.7933 - recall_m: 0.8043 - val_loss: 0.4499 - val_acc: 0.8395 - val_f1_m: 0.8327 - val_precision_m: 0.8368 - val_recall_m: 0.8600 - 73ms/epoch - 4ms/step
Epoch 8/15
19/19 - 0s - loss: 0.4428 - acc: 0.8052 - f1_m: 0.7904 - precision_m: 0.7979 - recall_m: 0.8174 - val_loss: 0.4375 - val_acc: 0.8272 - val_f1_m: 0.8202 - val_precision_m: 0.8283 - val_recall_m: 0.8500 - 71ms/epoch - 4ms/step
Epoch 9/15
19/19 - 0s - loss: 0.4278 - acc: 0.8040 - f1_m: 0.7898 - precision_m: 0.7956 - recall_m: 0.8190 - val_loss: 0.4319 - val_acc: 0.8333 - val_f1_m: 0.8270 - val_precision_m: 0.8297 - val_recall_m: 0.8600 - 71ms/epoch - 4ms/step
Epoch 10/15
19/19 - 0s - loss: 0.4159 - acc: 0.8052 - f1_m: 0.7918 - precision_m: 0.7971 - recall_m: 0.8228 - val_loss: 0.4308 - val_acc: 0.8272 - val_f1_m: 0.8215 - val_precision_m: 0.8292 - val_recall_m: 0.8500 - 73ms/epoch - 4ms/step
Epoch 11/15
19/19 - 0s - loss: 0.4055 - acc: 0.8151 - f1_m: 0.8033 - precision_m: 0.8021 - recall_m: 0.8333 - val_loss: 0.4319 - val_acc: 0.8272 - val_f1_m: 0.8232 - val_precision_m: 0.8302 - val_recall_m: 0.8525 - 73ms/epoch - 4ms/step
Epoch 12/15
19/19 - 0s - loss: 0.3966 - acc: 0.8176 - f1_m: 0.8045 - precision_m: 0.8036 - recall_m: 0.8320 - val_loss: 0.4338 - val_acc: 0.8272 - val_f1_m: 0.8232 - val_precision_m: 0.8302 - val_recall_m: 0.8525 - 72ms/epoch - 4ms/step
Epoch 13/15
19/19 - 0s - loss: 0.3889 - acc: 0.8263 - f1_m: 0.8126 - precision_m: 0.8128 - recall_m: 0.8363 - val_loss: 0.4359 - val_acc: 0.8210 - val_f1_m: 0.8164 - val_precision_m: 0.8289 - val_recall_m: 0.8425 - 70ms/epoch - 4ms/step
Epoch 14/15
19/19 - 0s - loss: 0.3817 - acc: 0.8313 - f1_m: 0.8194 - precision_m: 0.8159 - recall_m: 0.8454 - val_loss: 0.4382 - val_acc: 0.8148 - val_f1_m: 0.8110 - val_precision_m: 0.8224 - val_recall_m: 0.8425 - 72ms/epoch - 4ms/step
Epoch 15/15
19/19 - 0s - loss: 0.3750 - acc: 0.8313 - f1_m: 0.8196 - precision_m: 0.8159 - recall_m: 0.8456 - val_loss: 0.4403 - val_acc: 0.8148 - val_f1_m: 0.8110 - val_precision_m: 0.8224 - val_recall_m: 0.8425 - 71ms/epoch - 4ms/step
from matplotlib import pyplot as plt
%matplotlib inline
plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_2.history['f1_m'])
plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_2.history['val_f1_m'])
plt.title('Validation')
plt.subplot(2,2,3)
plt.plot(history_2.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')
plt.subplot(2,2,4)
plt.plot(history_2.history['val_loss'])
plt.xlabel('epochs')
Text(0.5, 0, 'epochs')
# Visualize some metrics associated with this model
# Source: Modified from in-class lecture
# now let's see how well the model performed
yhat_proba_2 = training_model_2.predict(ds_test) # sigmoid output probabilities
# use squeeze to get rid of any extra dimensions
yhat_2 = np.round(yhat_proba_2.squeeze()) # round to get binary class
conf_mat_2 = mt.confusion_matrix(y_test, yhat_2)
print(conf_mat_2)
print(mt.classification_report(y_test,yhat_2))
# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309. VitalBook file.
# Create pandas dataframe
conf_df_2 = pd.DataFrame(conf_mat_2, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())
# Create heatmap
sns.heatmap(conf_df_2, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 1s 3ms/step
[[65 15]
[15 67]]
precision recall f1-score support
0 0.81 0.81 0.81 80
1 0.82 0.82 0.82 82
accuracy 0.81 162
macro avg 0.81 0.81 0.81 162
weighted avg 0.81 0.81 0.81 162
With a little trial and error, I adjusted the epoch count down to 15, compared to 25 for model 1. On most runs, overfitting starts to appear after 9-15 epochs, with the loss on my validation data starting to trend upward. Because of this, I can say this model generally converges in fewer epochs. It also achieves an F1 score slightly better than model 1's, and the confusion matrix is roughly the same.
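Rather than tuning the epoch count by hand, an early-stopping rule can halt training once validation loss stops improving. Below is a plain-Python sketch of that logic; the val-loss sequence is illustrative, not taken from the run above. In Keras the equivalent would be passing `keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)` to `fit` via `callbacks=[...]`.

```python
# Sketch of the logic keras.callbacks.EarlyStopping implements: stop once the
# validation loss has failed to improve for `patience` consecutive epochs.
def best_stop_epoch(val_losses, patience=3):
    best, best_epoch, wait = float('inf'), 0, 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0  # new best: reset the counter
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # stop here; best weights were earlier
    return len(val_losses) - 1, best_epoch

# illustrative val-loss curve that bottoms out and then drifts upward
losses = [0.66, 0.54, 0.47, 0.44, 0.431, 0.432, 0.434, 0.436, 0.440]
stop, best = best_stop_epoch(losses)  # stops at epoch 7; best was epoch 4
```

With `restore_best_weights=True`, Keras would also roll the model back to the weights from the best epoch rather than keeping the final (slightly overfit) ones.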
Model 3 of 3¶
Here, I'll go back to my original crossed categorical features (feature_space_1) and try changing the optimization method to see what effect that has on the results. For this network I'll use RMSprop instead of Adam.
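For intuition on why this swap might matter: RMSprop scales each parameter's step by a running average of its squared gradients, while Adam additionally keeps a momentum term. Here's a minimal NumPy sketch of the RMSprop update rule, using the usual default hyperparameter names (lr, rho, eps); this is illustrative, not Keras's exact implementation.

```python
import numpy as np

def rmsprop_step(w, grad, cache, lr=0.001, rho=0.9, eps=1e-7):
    # running average of squared gradients, one entry per parameter
    cache = rho * cache + (1 - rho) * grad ** 2
    # divide by its square root so each parameter gets an adaptive step size
    w = w - lr * grad / (np.sqrt(cache) + eps)
    return w, cache

w = np.array([1.0, -2.0])
cache = np.zeros_like(w)
w, cache = rmsprop_step(w, grad=np.array([0.5, -0.5]), cache=cache)
```

Because each parameter's effective learning rate adapts independently, RMSprop can behave differently from Adam on the sparse embedding gradients in this network, though which one wins is an empirical question.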
# from keras.metrics import Precision, Recall
dict_inputs = feature_space_1.get_inputs() # need the unprocessed features here, to gain access to each output
# we need to create separate lists for each branch
crossed_outputs = []
# for each crossed variable, make an embedding
for col in feature_space_1.crossers.keys():
    x = setup_embedding_from_crossing(feature_space_1, col)
    # save these outputs in a list to concatenate later
    crossed_outputs.append(x)
# now concatenate the outputs and add a fully connected layer
wide_branch = Concatenate(name='wide_concat')(crossed_outputs)
# reset this input branch
all_deep_branch_outputs = []
# for each numeric variable, just add it in after embedding
for idx, col in enumerate(numeric_headers):
    x = feature_space_1.preprocessors[col].output
    x = tf.cast(x, float) # cast an integer as a float here
    all_deep_branch_outputs.append(x)
# for each categorical variable
for col in categorical_headers:
    # get the output tensor from the embedding layer
    x = setup_embedding_from_categorical(feature_space_1, col)
    # save these outputs in a list to concatenate later
    all_deep_branch_outputs.append(x)
# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34, activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17, activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=10, activation='relu', name='deep3')(deep_branch)
# merge the deep and wide branches
final_branch = Concatenate(name='concat_deep_wide')([deep_branch, wide_branch])
final_branch = Dense(units=1, activation='sigmoid', name='combined')(final_branch)
training_model_3 = keras.Model(inputs=dict_inputs, outputs=final_branch)
# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_3.compile(
    optimizer="RMSProp", loss="binary_crossentropy", metrics=['acc', f1_m, precision_m, recall_m]
)
training_model_3.summary()
plot_model(
    training_model_3, to_file='model.png', show_shapes=True, show_layer_names=True,
    rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_37"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
type_of_breast_surgery (InputL [(None, 1)] 0 []
ayer)
cancer_type_detailed (InputLay [(None, 1)] 0 []
er)
cellularity (InputLayer) [(None, 1)] 0 []
pam50_plus_claudin-low_subtype [(None, 1)] 0 []
(InputLayer)
er_status_measured_by_ihc (Inp [(None, 1)] 0 []
utLayer)
er_status (InputLayer) [(None, 1)] 0 []
her2_status_measured_by_snp6 ( [(None, 1)] 0 []
InputLayer)
her2_status (InputLayer) [(None, 1)] 0 []
tumor_other_histologic_subtype [(None, 1)] 0 []
(InputLayer)
inferred_menopausal_state (Inp [(None, 1)] 0 []
utLayer)
integrative_cluster (InputLaye [(None, 1)] 0 []
r)
primary_tumor_laterality (Inpu [(None, 1)] 0 []
tLayer)
oncotree_code (InputLayer) [(None, 1)] 0 []
pr_status (InputLayer) [(None, 1)] 0 []
3-gene_classifier_subtype (Inp [(None, 1)] 0 []
utLayer)
age_at_diagnosis (InputLayer) [(None, 1)] 0 []
neoplasm_histologic_grade (Inp [(None, 1)] 0 []
utLayer)
lymph_nodes_examined_positive [(None, 1)] 0 []
(InputLayer)
mutation_count (InputLayer) [(None, 1)] 0 []
nottingham_prognostic_index (I [(None, 1)] 0 []
nputLayer)
overall_survival_months (Input [(None, 1)] 0 []
Layer)
tumor_size (InputLayer) [(None, 1)] 0 []
tumor_stage (InputLayer) [(None, 1)] 0 []
chemotherapy (InputLayer) [(None, 1)] 0 []
radio_therapy (InputLayer) [(None, 1)] 0 []
hormone_therapy (InputLayer) [(None, 1)] 0 []
string_categorical_424_preproc (None, 1) 0 ['type_of_breast_surgery[0][0]']
essor (StringLookup)
string_categorical_425_preproc (None, 1) 0 ['cancer_type_detailed[0][0]']
essor (StringLookup)
string_categorical_426_preproc (None, 1) 0 ['cellularity[0][0]']
essor (StringLookup)
string_categorical_427_preproc (None, 1) 0 ['pam50_plus_claudin-low_subtype[
essor (StringLookup) 0][0]']
string_categorical_428_preproc (None, 1) 0 ['er_status_measured_by_ihc[0][0]
essor (StringLookup) ']
string_categorical_429_preproc (None, 1) 0 ['er_status[0][0]']
essor (StringLookup)
string_categorical_430_preproc (None, 1) 0 ['her2_status_measured_by_snp6[0]
essor (StringLookup) [0]']
string_categorical_431_preproc (None, 1) 0 ['her2_status[0][0]']
essor (StringLookup)
string_categorical_432_preproc (None, 1) 0 ['tumor_other_histologic_subtype[
essor (StringLookup) 0][0]']
string_categorical_433_preproc (None, 1) 0 ['inferred_menopausal_state[0][0]
essor (StringLookup) ']
string_categorical_434_preproc (None, 1) 0 ['integrative_cluster[0][0]']
essor (StringLookup)
string_categorical_435_preproc (None, 1) 0 ['primary_tumor_laterality[0][0]'
essor (StringLookup) ]
string_categorical_436_preproc (None, 1) 0 ['oncotree_code[0][0]']
essor (StringLookup)
string_categorical_437_preproc (None, 1) 0 ['pr_status[0][0]']
essor (StringLookup)
string_categorical_438_preproc (None, 1) 0 ['3-gene_classifier_subtype[0][0]
essor (StringLookup) ']
float_normalized_306_preproces (None, 1) 3 ['age_at_diagnosis[0][0]']
sor (Normalization)
float_normalized_307_preproces (None, 1) 3 ['neoplasm_histologic_grade[0][0]
sor (Normalization) ']
float_normalized_308_preproces (None, 1) 3 ['lymph_nodes_examined_positive[0
sor (Normalization) ][0]']
float_normalized_309_preproces (None, 1) 3 ['mutation_count[0][0]']
sor (Normalization)
float_normalized_310_preproces (None, 1) 3 ['nottingham_prognostic_index[0][
sor (Normalization) 0]']
float_normalized_311_preproces (None, 1) 3 ['overall_survival_months[0][0]']
sor (Normalization)
float_normalized_312_preproces (None, 1) 3 ['tumor_size[0][0]']
sor (Normalization)
float_normalized_313_preproces (None, 1) 3 ['tumor_stage[0][0]']
sor (Normalization)
float_normalized_303_preproces (None, 1) 3 ['chemotherapy[0][0]']
sor (Normalization)
float_normalized_305_preproces (None, 1) 3 ['radio_therapy[0][0]']
sor (Normalization)
float_normalized_304_preproces (None, 1) 3 ['hormone_therapy[0][0]']
sor (Normalization)
type_of_breast_surgery_embed ( (None, 1, 1) 2 ['string_categorical_424_preproce
Embedding) ssor[0][0]']
cancer_type_detailed_embed (Em (None, 1, 2) 10 ['string_categorical_425_preproce
bedding) ssor[0][0]']
cellularity_embed (Embedding) (None, 1, 1) 3 ['string_categorical_426_preproce
ssor[0][0]']
pam50_plus_claudin-low_subtype (None, 1, 2) 14 ['string_categorical_427_preproce
_embed (Embedding) ssor[0][0]']
er_status_measured_by_ihc_embe (None, 1, 1) 2 ['string_categorical_428_preproce
d (Embedding) ssor[0][0]']
er_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_429_preproce
ssor[0][0]']
her2_status_measured_by_snp6_e (None, 1, 2) 8 ['string_categorical_430_preproce
mbed (Embedding) ssor[0][0]']
her2_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_431_preproce
ssor[0][0]']
tumor_other_histologic_subtype (None, 1, 2) 14 ['string_categorical_432_preproce
_embed (Embedding) ssor[0][0]']
inferred_menopausal_state_embe (None, 1, 1) 2 ['string_categorical_433_preproce
d (Embedding) ssor[0][0]']
integrative_cluster_embed (Emb (None, 1, 3) 33 ['string_categorical_434_preproce
edding) ssor[0][0]']
primary_tumor_laterality_embed (None, 1, 1) 2 ['string_categorical_435_preproce
(Embedding) ssor[0][0]']
oncotree_code_embed (Embedding (None, 1, 2) 10 ['string_categorical_436_preproce
) ssor[0][0]']
pr_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_437_preproce
ssor[0][0]']
3-gene_classifier_subtype_embe (None, 1, 2) 8 ['string_categorical_438_preproce
d (Embedding) ssor[0][0]']
tf.cast_407 (TFOpLambda) (None, 1) 0 ['float_normalized_306_preprocess
or[0][0]']
tf.cast_408 (TFOpLambda) (None, 1) 0 ['float_normalized_307_preprocess
or[0][0]']
tf.cast_409 (TFOpLambda) (None, 1) 0 ['float_normalized_308_preprocess
or[0][0]']
tf.cast_410 (TFOpLambda) (None, 1) 0 ['float_normalized_309_preprocess
or[0][0]']
tf.cast_411 (TFOpLambda) (None, 1) 0 ['float_normalized_310_preprocess
or[0][0]']
tf.cast_412 (TFOpLambda) (None, 1) 0 ['float_normalized_311_preprocess
or[0][0]']
tf.cast_413 (TFOpLambda) (None, 1) 0 ['float_normalized_312_preprocess
or[0][0]']
tf.cast_414 (TFOpLambda) (None, 1) 0 ['float_normalized_313_preprocess
or[0][0]']
tf.cast_415 (TFOpLambda) (None, 1) 0 ['float_normalized_303_preprocess
or[0][0]']
tf.cast_416 (TFOpLambda) (None, 1) 0 ['float_normalized_305_preprocess
or[0][0]']
tf.cast_417 (TFOpLambda) (None, 1) 0 ['float_normalized_304_preprocess
or[0][0]']
flatten_656 (Flatten) (None, 1) 0 ['type_of_breast_surgery_embed[0]
[0]']
flatten_657 (Flatten) (None, 2) 0 ['cancer_type_detailed_embed[0][0
]']
flatten_658 (Flatten) (None, 1) 0 ['cellularity_embed[0][0]']
flatten_659 (Flatten) (None, 2) 0 ['pam50_plus_claudin-low_subtype_
embed[0][0]']
flatten_660 (Flatten) (None, 1) 0 ['er_status_measured_by_ihc_embed
[0][0]']
flatten_661 (Flatten) (None, 1) 0 ['er_status_embed[0][0]']
flatten_662 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_em
bed[0][0]']
flatten_663 (Flatten) (None, 1) 0 ['her2_status_embed[0][0]']
flatten_664 (Flatten) (None, 2) 0 ['tumor_other_histologic_subtype_
embed[0][0]']
flatten_665 (Flatten) (None, 1) 0 ['inferred_menopausal_state_embed
[0][0]']
flatten_666 (Flatten) (None, 3) 0 ['integrative_cluster_embed[0][0]
']
flatten_667 (Flatten) (None, 1) 0 ['primary_tumor_laterality_embed[
0][0]']
flatten_668 (Flatten) (None, 2) 0 ['oncotree_code_embed[0][0]']
flatten_669 (Flatten) (None, 1) 0 ['pr_status_embed[0][0]']
flatten_670 (Flatten) (None, 2) 0 ['3-gene_classifier_subtype_embed
[0][0]']
embed_concat (Concatenate) (None, 34) 0 ['tf.cast_407[0][0]',
'tf.cast_408[0][0]',
'tf.cast_409[0][0]',
'tf.cast_410[0][0]',
'tf.cast_411[0][0]',
'tf.cast_412[0][0]',
'tf.cast_413[0][0]',
'tf.cast_414[0][0]',
'tf.cast_415[0][0]',
'tf.cast_416[0][0]',
'tf.cast_417[0][0]',
'flatten_656[0][0]',
'flatten_657[0][0]',
'flatten_658[0][0]',
'flatten_659[0][0]',
'flatten_660[0][0]',
'flatten_661[0][0]',
'flatten_662[0][0]',
'flatten_663[0][0]',
'flatten_664[0][0]',
'flatten_665[0][0]',
'flatten_666[0][0]',
'flatten_667[0][0]',
'flatten_668[0][0]',
'flatten_669[0][0]',
'flatten_670[0][0]']
her2_status_measured_by_snp6_X (None, 1) 0 ['string_categorical_430_preproce
_her2_status (HashedCrossing) ssor[0][0]',
'string_categorical_431_preproce
ssor[0][0]']
3-gene_classifier_subtype_X_in (None, 1) 0 ['string_categorical_438_preproce
tegrative_cluster (HashedCross ssor[0][0]',
ing) 'string_categorical_434_preproce
ssor[0][0]']
er_status_X_er_status_measured (None, 1) 0 ['string_categorical_429_preproce
_by_ihc (HashedCrossing) ssor[0][0]',
'string_categorical_428_preproce
ssor[0][0]']
deep1 (Dense) (None, 34) 1190 ['embed_concat[0][0]']
her2_status_measured_by_snp6_X (None, 1, 2) 16 ['her2_status_measured_by_snp6_X_
_her2_status_embed (Embedding) her2_status[0][0]']
3-gene_classifier_subtype_X_in (None, 1, 6) 264 ['3-gene_classifier_subtype_X_int
tegrative_cluster_embed (Embed egrative_cluster[0][0]']
ding)
er_status_X_er_status_measured (None, 1, 2) 8 ['er_status_X_er_status_measured_
_by_ihc_embed (Embedding) by_ihc[0][0]']
deep2 (Dense) (None, 17) 595 ['deep1[0][0]']
flatten_653 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_X_
her2_status_embed[0][0]']
flatten_654 (Flatten) (None, 6) 0 ['3-gene_classifier_subtype_X_int
egrative_cluster_embed[0][0]']
flatten_655 (Flatten) (None, 2) 0 ['er_status_X_er_status_measured_
by_ihc_embed[0][0]']
deep3 (Dense) (None, 10) 180 ['deep2[0][0]']
wide_concat (Concatenate) (None, 10) 0 ['flatten_653[0][0]',
'flatten_654[0][0]',
'flatten_655[0][0]']
concat_deep_wide (Concatenate) (None, 20) 0 ['deep3[0][0]',
'wide_concat[0][0]']
combined (Dense) (None, 1) 21 ['concat_deep_wide[0][0]']
==================================================================================================
Total params: 2,421
Trainable params: 2,388
Non-trainable params: 33
__________________________________________________________________________________________________
# train using the already processed features
history_3 = training_model_3.fit(
    ds_train, epochs=20, validation_data=ds_test, verbose=2
)
Epoch 1/20 19/19 - 4s - loss: 0.6477 - acc: 0.6365 - f1_m: 0.4069 - precision_m: 0.8202 - recall_m: 0.2962 - val_loss: 0.6280 - val_acc: 0.7222 - val_f1_m: 0.6945 - val_precision_m: 0.7586 - val_recall_m: 0.6725 - 4s/epoch - 194ms/step Epoch 2/20 19/19 - 0s - loss: 0.5937 - acc: 0.7320 - f1_m: 0.6863 - precision_m: 0.7916 - recall_m: 0.6527 - val_loss: 0.5864 - val_acc: 0.7346 - val_f1_m: 0.7364 - val_precision_m: 0.7493 - val_recall_m: 0.7842 - 75ms/epoch - 4ms/step Epoch 3/20 19/19 - 0s - loss: 0.5551 - acc: 0.7680 - f1_m: 0.7347 - precision_m: 0.8019 - recall_m: 0.7231 - val_loss: 0.5488 - val_acc: 0.7407 - val_f1_m: 0.7507 - val_precision_m: 0.7373 - val_recall_m: 0.8142 - 71ms/epoch - 4ms/step Epoch 4/20 19/19 - 0s - loss: 0.5223 - acc: 0.7742 - f1_m: 0.7449 - precision_m: 0.7896 - recall_m: 0.7459 - val_loss: 0.5182 - val_acc: 0.7593 - val_f1_m: 0.7727 - val_precision_m: 0.7463 - val_recall_m: 0.8550 - 71ms/epoch - 4ms/step Epoch 5/20 19/19 - 0s - loss: 0.4965 - acc: 0.7841 - f1_m: 0.7599 - precision_m: 0.7909 - recall_m: 0.7688 - val_loss: 0.4942 - val_acc: 0.7716 - val_f1_m: 0.7834 - val_precision_m: 0.7577 - val_recall_m: 0.8650 - 72ms/epoch - 4ms/step Epoch 6/20 19/19 - 0s - loss: 0.4768 - acc: 0.7953 - f1_m: 0.7732 - precision_m: 0.7942 - recall_m: 0.7877 - val_loss: 0.4766 - val_acc: 0.7778 - val_f1_m: 0.7873 - val_precision_m: 0.7669 - val_recall_m: 0.8650 - 76ms/epoch - 4ms/step Epoch 7/20 19/19 - 0s - loss: 0.4616 - acc: 0.8002 - f1_m: 0.7782 - precision_m: 0.7968 - recall_m: 0.7952 - val_loss: 0.4642 - val_acc: 0.7901 - val_f1_m: 0.7981 - val_precision_m: 0.7733 - val_recall_m: 0.8750 - 71ms/epoch - 4ms/step Epoch 8/20 19/19 - 0s - loss: 0.4509 - acc: 0.8027 - f1_m: 0.7833 - precision_m: 0.7959 - recall_m: 0.8088 - val_loss: 0.4563 - val_acc: 0.7840 - val_f1_m: 0.7913 - val_precision_m: 0.7677 - val_recall_m: 0.8650 - 71ms/epoch - 4ms/step Epoch 9/20 19/19 - 0s - loss: 0.4425 - acc: 0.8065 - f1_m: 0.7859 - precision_m: 0.7980 - recall_m: 0.8089 - 
val_loss: 0.4517 - val_acc: 0.7840 - val_f1_m: 0.7913 - val_precision_m: 0.7677 - val_recall_m: 0.8650 - 72ms/epoch - 4ms/step Epoch 10/20 19/19 - 0s - loss: 0.4360 - acc: 0.8052 - f1_m: 0.7848 - precision_m: 0.7956 - recall_m: 0.8089 - val_loss: 0.4487 - val_acc: 0.7901 - val_f1_m: 0.7958 - val_precision_m: 0.7727 - val_recall_m: 0.8650 - 70ms/epoch - 4ms/step Epoch 11/20 19/19 - 0s - loss: 0.4301 - acc: 0.8089 - f1_m: 0.7915 - precision_m: 0.7989 - recall_m: 0.8164 - val_loss: 0.4471 - val_acc: 0.7963 - val_f1_m: 0.8000 - val_precision_m: 0.7825 - val_recall_m: 0.8650 - 68ms/epoch - 4ms/step Epoch 12/20 19/19 - 0s - loss: 0.4254 - acc: 0.8089 - f1_m: 0.7922 - precision_m: 0.8020 - recall_m: 0.8164 - val_loss: 0.4464 - val_acc: 0.7963 - val_f1_m: 0.8000 - val_precision_m: 0.7825 - val_recall_m: 0.8650 - 69ms/epoch - 4ms/step Epoch 13/20 19/19 - 0s - loss: 0.4208 - acc: 0.8127 - f1_m: 0.7960 - precision_m: 0.8053 - recall_m: 0.8192 - val_loss: 0.4459 - val_acc: 0.7963 - val_f1_m: 0.8000 - val_precision_m: 0.7825 - val_recall_m: 0.8650 - 67ms/epoch - 4ms/step Epoch 14/20 19/19 - 0s - loss: 0.4164 - acc: 0.8176 - f1_m: 0.8040 - precision_m: 0.8068 - recall_m: 0.8320 - val_loss: 0.4459 - val_acc: 0.7963 - val_f1_m: 0.8000 - val_precision_m: 0.7825 - val_recall_m: 0.8650 - 69ms/epoch - 4ms/step Epoch 15/20 19/19 - 0s - loss: 0.4124 - acc: 0.8238 - f1_m: 0.8100 - precision_m: 0.8134 - recall_m: 0.8344 - val_loss: 0.4459 - val_acc: 0.7901 - val_f1_m: 0.7930 - val_precision_m: 0.7772 - val_recall_m: 0.8550 - 71ms/epoch - 4ms/step Epoch 16/20 19/19 - 0s - loss: 0.4085 - acc: 0.8263 - f1_m: 0.8136 - precision_m: 0.8211 - recall_m: 0.8357 - val_loss: 0.4460 - val_acc: 0.7963 - val_f1_m: 0.7973 - val_precision_m: 0.7884 - val_recall_m: 0.8550 - 70ms/epoch - 4ms/step Epoch 17/20 19/19 - 0s - loss: 0.4047 - acc: 0.8313 - f1_m: 0.8171 - precision_m: 0.8262 - recall_m: 0.8354 - val_loss: 0.4462 - val_acc: 0.8025 - val_f1_m: 0.8024 - val_precision_m: 0.7943 - val_recall_m: 0.8550 
- 72ms/epoch - 4ms/step Epoch 18/20 19/19 - 0s - loss: 0.4011 - acc: 0.8300 - f1_m: 0.8162 - precision_m: 0.8244 - recall_m: 0.8354 - val_loss: 0.4465 - val_acc: 0.8025 - val_f1_m: 0.8024 - val_precision_m: 0.7943 - val_recall_m: 0.8550 - 83ms/epoch - 4ms/step Epoch 19/20 19/19 - 0s - loss: 0.3975 - acc: 0.8325 - f1_m: 0.8198 - precision_m: 0.8237 - recall_m: 0.8427 - val_loss: 0.4467 - val_acc: 0.8025 - val_f1_m: 0.8024 - val_precision_m: 0.7943 - val_recall_m: 0.8550 - 81ms/epoch - 4ms/step Epoch 20/20 19/19 - 0s - loss: 0.3938 - acc: 0.8362 - f1_m: 0.8230 - precision_m: 0.8286 - recall_m: 0.8427 - val_loss: 0.4471 - val_acc: 0.8025 - val_f1_m: 0.8024 - val_precision_m: 0.7943 - val_recall_m: 0.8550 - 73ms/epoch - 4ms/step
from matplotlib import pyplot as plt
%matplotlib inline
plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_3.history['f1_m'])
plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_3.history['val_f1_m'])
plt.title('Validation')
plt.subplot(2,2,3)
plt.plot(history_3.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')
plt.subplot(2,2,4)
plt.plot(history_3.history['val_loss'])
plt.xlabel('epochs')
Text(0.5, 0, 'epochs')
# Visualize some metrics associated with this model
# Source: Modified from in-class lecture
# Use the sklearn metrics here, if you want to
# from sklearn import metrics as mt
# y_test = tf.concat([y for x, y in ds_test], axis=0)
# y_test = y_test.numpy()
# now let's see how well the model performed
yhat_proba_3 = training_model_3.predict(ds_test) # sigmoid output probabilities
# use squeeze to get rid of any extra dimensions
yhat_3 = np.round(yhat_proba_3.squeeze()) # round to get binary class
conf_mat_3 = mt.confusion_matrix(y_test, yhat_3)
print(conf_mat_3)
print(mt.classification_report(y_test,yhat_3))
# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309. VitalBook file.
# Create pandas dataframe
conf_df_3 = pd.DataFrame(conf_mat_3, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())
# Create heatmap
sns.heatmap(conf_df_3, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 1s 2ms/step
[[62 18]
[14 68]]
precision recall f1-score support
0 0.82 0.78 0.79 80
1 0.79 0.83 0.81 82
accuracy 0.80 162
macro avg 0.80 0.80 0.80 162
weighted avg 0.80 0.80 0.80 162
This third model performed very similarly to the first, which I found interesting. The F1 score is very close, but the confusion matrix is a little worse. I can't draw any real conclusions about the change in optimizer. After running all of these models several times, I can say the results jumped around: F1 scores ranged from 0.80 to 0.84 across all three models, and the best performer switched from run to run. Any of these would be a viable candidate to carry forward to the next evaluation.
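Since the ranking flips between runs, one way to make this comparison more defensible is to fix the random seeds (e.g. `tf.keras.utils.set_random_seed`, available in recent TensorFlow releases) or to train each model several times and report the mean and spread. A small sketch of the latter, with hypothetical F1 scores standing in for repeated runs (the values are illustrative, not my actual numbers):

```python
import numpy as np

# hypothetical F1 scores from three repeated runs of each model
f1_runs = {
    'model_1': [0.81, 0.83, 0.80],
    'model_2': [0.82, 0.84, 0.81],
    'model_3': [0.80, 0.82, 0.83],
}
# mean +/- std is a more honest basis for choosing a model than a single run
summary = {name: (float(np.mean(s)), float(np.std(s))) for name, s in f1_runs.items()}
```

If the means differ by less than roughly one standard deviation, as the 0.80-0.84 spread here suggests, the models are effectively tied and the choice can rest on other criteria such as convergence speed.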
Investigating Generalization Performance¶
For this portion of the lab I will use model 2 of 3 from above, which means I'll be using training_model_2 and feature_space_2. Any of the models would likely have been appropriate since their results were comparable, but over multiple runs this one appeared to converge fastest on the validation data.
Per the instructions, I consider model 2 of 3 above to be one of the two models required for this portion of the rubric. I therefore need to alter model 2 to see whether different performance can be achieved. To do this, I'll add a layer to the deep portion of the network to step it down a little closer to my final binary output layer, and I'll change the number of neurons in the last two layers. Finally, I altered the number of epochs after a few test runs showed the model taking longer to converge on the validation data.
# from keras.metrics import Precision, Recall
dict_inputs = feature_space_2.get_inputs() # need the unprocessed features here, to gain access to each output
# we need to create separate lists for each branch
crossed_outputs = []
# for each crossed variable, make an embedding
for col in feature_space_2.crossers.keys():
    x = setup_embedding_from_crossing(feature_space_2, col)
    # save these outputs in a list to concatenate later
    crossed_outputs.append(x)
# now concatenate the outputs and add a fully connected layer
wide_branch = Concatenate(name='wide_concat')(crossed_outputs)
# reset this input branch
all_deep_branch_outputs = []
# for each numeric variable, just add it in after embedding
for idx, col in enumerate(numeric_headers):
    x = feature_space_2.preprocessors[col].output
    x = tf.cast(x, float) # cast an integer as a float here
    all_deep_branch_outputs.append(x)
# for each categorical variable
for col in categorical_headers:
    # get the output tensor from the embedding layer
    x = setup_embedding_from_categorical(feature_space_2, col)
    # save these outputs in a list to concatenate later
    all_deep_branch_outputs.append(x)
# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34, activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17, activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=8, activation='relu', name='deep3')(deep_branch) # changed from 10 to 8 neurons
deep_branch = Dense(units=4, activation='relu', name='deep4')(deep_branch) # this is my new layer
# merge the deep and wide branches
final_branch = Concatenate(name='concat_deep_wide')([deep_branch, wide_branch])
final_branch = Dense(units=1, activation='sigmoid', name='combined')(final_branch)
training_model_4 = keras.Model(inputs=dict_inputs, outputs=final_branch)
# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_4.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=['acc', f1_m, precision_m, recall_m]
)
training_model_4.summary()
plot_model(
    training_model_4, to_file='model.png', show_shapes=True, show_layer_names=True,
    rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_38"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
type_of_breast_surgery (InputL [(None, 1)] 0 []
ayer)
cancer_type_detailed (InputLay [(None, 1)] 0 []
er)
cellularity (InputLayer) [(None, 1)] 0 []
pam50_plus_claudin-low_subtype [(None, 1)] 0 []
(InputLayer)
er_status_measured_by_ihc (Inp [(None, 1)] 0 []
utLayer)
er_status (InputLayer) [(None, 1)] 0 []
her2_status_measured_by_snp6 ( [(None, 1)] 0 []
InputLayer)
her2_status (InputLayer) [(None, 1)] 0 []
tumor_other_histologic_subtype [(None, 1)] 0 []
(InputLayer)
inferred_menopausal_state (Inp [(None, 1)] 0 []
utLayer)
integrative_cluster (InputLaye [(None, 1)] 0 []
r)
primary_tumor_laterality (Inpu [(None, 1)] 0 []
tLayer)
oncotree_code (InputLayer) [(None, 1)] 0 []
pr_status (InputLayer) [(None, 1)] 0 []
3-gene_classifier_subtype (Inp [(None, 1)] 0 []
utLayer)
age_at_diagnosis (InputLayer) [(None, 1)] 0 []
neoplasm_histologic_grade (Inp [(None, 1)] 0 []
utLayer)
lymph_nodes_examined_positive [(None, 1)] 0 []
(InputLayer)
mutation_count (InputLayer) [(None, 1)] 0 []
nottingham_prognostic_index (I [(None, 1)] 0 []
nputLayer)
overall_survival_months (Input [(None, 1)] 0 []
Layer)
tumor_size (InputLayer) [(None, 1)] 0 []
tumor_stage (InputLayer) [(None, 1)] 0 []
chemotherapy (InputLayer) [(None, 1)] 0 []
radio_therapy (InputLayer) [(None, 1)] 0 []
hormone_therapy (InputLayer) [(None, 1)] 0 []
string_categorical_439_preproc (None, 1) 0 ['type_of_breast_surgery[0][0]']
essor (StringLookup)
string_categorical_440_preproc (None, 1) 0 ['cancer_type_detailed[0][0]']
essor (StringLookup)
string_categorical_441_preproc (None, 1) 0 ['cellularity[0][0]']
essor (StringLookup)
string_categorical_442_preproc (None, 1) 0 ['pam50_plus_claudin-low_subtype[
essor (StringLookup) 0][0]']
string_categorical_443_preproc (None, 1) 0 ['er_status_measured_by_ihc[0][0]
essor (StringLookup) ']
string_categorical_444_preproc (None, 1) 0 ['er_status[0][0]']
essor (StringLookup)
string_categorical_445_preproc (None, 1) 0 ['her2_status_measured_by_snp6[0]
essor (StringLookup) [0]']
string_categorical_446_preproc (None, 1) 0 ['her2_status[0][0]']
essor (StringLookup)
string_categorical_447_preproc (None, 1) 0 ['tumor_other_histologic_subtype[
essor (StringLookup) 0][0]']
string_categorical_448_preproc (None, 1) 0 ['inferred_menopausal_state[0][0]
essor (StringLookup) ']
string_categorical_449_preproc (None, 1) 0 ['integrative_cluster[0][0]']
essor (StringLookup)
string_categorical_450_preproc (None, 1) 0 ['primary_tumor_laterality[0][0]'
essor (StringLookup) ]
string_categorical_451_preproc (None, 1) 0 ['oncotree_code[0][0]']
essor (StringLookup)
string_categorical_452_preproc (None, 1) 0 ['pr_status[0][0]']
essor (StringLookup)
string_categorical_453_preproc (None, 1) 0 ['3-gene_classifier_subtype[0][0]
essor (StringLookup) ']
float_normalized_317_preproces (None, 1) 3 ['age_at_diagnosis[0][0]']
sor (Normalization)
float_normalized_318_preproces (None, 1) 3 ['neoplasm_histologic_grade[0][0]
sor (Normalization) ']
float_normalized_319_preproces (None, 1) 3 ['lymph_nodes_examined_positive[0
sor (Normalization) ][0]']
float_normalized_320_preproces (None, 1) 3 ['mutation_count[0][0]']
sor (Normalization)
float_normalized_321_preproces (None, 1) 3 ['nottingham_prognostic_index[0][
sor (Normalization) 0]']
float_normalized_322_preproces (None, 1) 3 ['overall_survival_months[0][0]']
sor (Normalization)
float_normalized_323_preproces (None, 1) 3 ['tumor_size[0][0]']
sor (Normalization)
float_normalized_324_preproces (None, 1) 3 ['tumor_stage[0][0]']
sor (Normalization)
float_normalized_314_preproces (None, 1) 3 ['chemotherapy[0][0]']
sor (Normalization)
float_normalized_316_preproces (None, 1) 3 ['radio_therapy[0][0]']
sor (Normalization)
float_normalized_315_preproces (None, 1) 3 ['hormone_therapy[0][0]']
sor (Normalization)
type_of_breast_surgery_embed ( (None, 1, 1) 2 ['string_categorical_439_preproce
Embedding) ssor[0][0]']
cancer_type_detailed_embed (Em (None, 1, 2) 10 ['string_categorical_440_preproce
bedding) ssor[0][0]']
cellularity_embed (Embedding) (None, 1, 1) 3 ['string_categorical_441_preproce
ssor[0][0]']
pam50_plus_claudin-low_subtype (None, 1, 2) 14 ['string_categorical_442_preproce
_embed (Embedding) ssor[0][0]']
er_status_measured_by_ihc_embe (None, 1, 1) 2 ['string_categorical_443_preproce
d (Embedding) ssor[0][0]']
er_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_444_preproce
ssor[0][0]']
her2_status_measured_by_snp6_e (None, 1, 2) 8 ['string_categorical_445_preproce
mbed (Embedding) ssor[0][0]']
her2_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_446_preproce
ssor[0][0]']
tumor_other_histologic_subtype (None, 1, 2) 14 ['string_categorical_447_preproce
_embed (Embedding) ssor[0][0]']
inferred_menopausal_state_embe (None, 1, 1) 2 ['string_categorical_448_preproce
d (Embedding) ssor[0][0]']
integrative_cluster_embed (Emb (None, 1, 3) 33 ['string_categorical_449_preproce
edding) ssor[0][0]']
primary_tumor_laterality_embed (None, 1, 1) 2 ['string_categorical_450_preproce
(Embedding) ssor[0][0]']
oncotree_code_embed (Embedding (None, 1, 2) 10 ['string_categorical_451_preproce
) ssor[0][0]']
pr_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_452_preproce
ssor[0][0]']
3-gene_classifier_subtype_embe (None, 1, 2) 8 ['string_categorical_453_preproce
d (Embedding) ssor[0][0]']
tf.cast_418 (TFOpLambda) (None, 1) 0 ['float_normalized_317_preprocess
or[0][0]']
tf.cast_419 (TFOpLambda) (None, 1) 0 ['float_normalized_318_preprocess
or[0][0]']
tf.cast_420 (TFOpLambda) (None, 1) 0 ['float_normalized_319_preprocess
or[0][0]']
tf.cast_421 (TFOpLambda) (None, 1) 0 ['float_normalized_320_preprocess
or[0][0]']
tf.cast_422 (TFOpLambda) (None, 1) 0 ['float_normalized_321_preprocess
or[0][0]']
tf.cast_423 (TFOpLambda) (None, 1) 0 ['float_normalized_322_preprocess
or[0][0]']
tf.cast_424 (TFOpLambda) (None, 1) 0 ['float_normalized_323_preprocess
or[0][0]']
tf.cast_425 (TFOpLambda) (None, 1) 0 ['float_normalized_324_preprocess
or[0][0]']
tf.cast_426 (TFOpLambda) (None, 1) 0 ['float_normalized_314_preprocess
or[0][0]']
tf.cast_427 (TFOpLambda) (None, 1) 0 ['float_normalized_316_preprocess
or[0][0]']
tf.cast_428 (TFOpLambda) (None, 1) 0 ['float_normalized_315_preprocess
or[0][0]']
flatten_674 (Flatten) (None, 1) 0 ['type_of_breast_surgery_embed[0]
[0]']
flatten_675 (Flatten) (None, 2) 0 ['cancer_type_detailed_embed[0][0
]']
flatten_676 (Flatten) (None, 1) 0 ['cellularity_embed[0][0]']
flatten_677 (Flatten) (None, 2) 0 ['pam50_plus_claudin-low_subtype_
embed[0][0]']
flatten_678 (Flatten) (None, 1) 0 ['er_status_measured_by_ihc_embed
[0][0]']
flatten_679 (Flatten) (None, 1) 0 ['er_status_embed[0][0]']
flatten_680 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_em
bed[0][0]']
flatten_681 (Flatten) (None, 1) 0 ['her2_status_embed[0][0]']
flatten_682 (Flatten) (None, 2) 0 ['tumor_other_histologic_subtype_
embed[0][0]']
flatten_683 (Flatten) (None, 1) 0 ['inferred_menopausal_state_embed
[0][0]']
flatten_684 (Flatten) (None, 3) 0 ['integrative_cluster_embed[0][0]
']
flatten_685 (Flatten) (None, 1) 0 ['primary_tumor_laterality_embed[
0][0]']
flatten_686 (Flatten) (None, 2) 0 ['oncotree_code_embed[0][0]']
flatten_687 (Flatten) (None, 1) 0 ['pr_status_embed[0][0]']
flatten_688 (Flatten) (None, 2) 0 ['3-gene_classifier_subtype_embed
[0][0]']
embed_concat (Concatenate) (None, 34) 0 ['tf.cast_418[0][0]',
'tf.cast_419[0][0]',
'tf.cast_420[0][0]',
'tf.cast_421[0][0]',
'tf.cast_422[0][0]',
'tf.cast_423[0][0]',
'tf.cast_424[0][0]',
'tf.cast_425[0][0]',
'tf.cast_426[0][0]',
'tf.cast_427[0][0]',
'tf.cast_428[0][0]',
'flatten_674[0][0]',
'flatten_675[0][0]',
'flatten_676[0][0]',
'flatten_677[0][0]',
'flatten_678[0][0]',
'flatten_679[0][0]',
'flatten_680[0][0]',
'flatten_681[0][0]',
'flatten_682[0][0]',
'flatten_683[0][0]',
'flatten_684[0][0]',
'flatten_685[0][0]',
'flatten_686[0][0]',
'flatten_687[0][0]',
'flatten_688[0][0]']
deep1 (Dense) (None, 34) 1190 ['embed_concat[0][0]']
her2_status_measured_by_snp6_X (None, 1) 0 ['string_categorical_445_preproce
_her2_status (HashedCrossing) ssor[0][0]',
'string_categorical_446_preproce
ssor[0][0]']
3-gene_classifier_subtype_X_in (None, 1) 0 ['string_categorical_453_preproce
tegrative_cluster_X_pam50_plus ssor[0][0]',
_claudin-low_subtype (HashedCr 'string_categorical_449_preproce
ossing) ssor[0][0]',
'string_categorical_442_preproce
ssor[0][0]']
er_status_X_er_status_measured (None, 1) 0 ['string_categorical_444_preproce
_by_ihc (HashedCrossing) ssor[0][0]',
'string_categorical_443_preproce
ssor[0][0]']
deep2 (Dense) (None, 17) 595 ['deep1[0][0]']
her2_status_measured_by_snp6_X (None, 1, 2) 16 ['her2_status_measured_by_snp6_X_
_her2_status_embed (Embedding) her2_status[0][0]']
3-gene_classifier_subtype_X_in (None, 1, 17) 5236 ['3-gene_classifier_subtype_X_int
tegrative_cluster_X_pam50_plus egrative_cluster_X_pam50_plus_cla
_claudin-low_subtype_embed (Em udin-low_subtype[0][0]']
bedding)
er_status_X_er_status_measured (None, 1, 2) 8 ['er_status_X_er_status_measured_
_by_ihc_embed (Embedding) by_ihc[0][0]']
deep3 (Dense) (None, 8) 144 ['deep2[0][0]']
flatten_671 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_X_
her2_status_embed[0][0]']
flatten_672 (Flatten) (None, 17) 0 ['3-gene_classifier_subtype_X_int
egrative_cluster_X_pam50_plus_cla
udin-low_subtype_embed[0][0]']
flatten_673 (Flatten) (None, 2) 0 ['er_status_X_er_status_measured_
by_ihc_embed[0][0]']
deep4 (Dense) (None, 4) 36 ['deep3[0][0]']
wide_concat (Concatenate) (None, 21) 0 ['flatten_671[0][0]',
'flatten_672[0][0]',
'flatten_673[0][0]']
concat_deep_wide (Concatenate) (None, 25) 0 ['deep4[0][0]',
'wide_concat[0][0]']
combined (Dense) (None, 1) 26 ['concat_deep_wide[0][0]']
==================================================================================================
Total params: 7,398
Trainable params: 7,365
Non-trainable params: 33
__________________________________________________________________________________________________
# train using the already processed features
history_4 = training_model_4.fit(
    ds_train, epochs=25, validation_data=ds_test, verbose=2
)  # changed number of epochs
Epoch 1/25
19/19 - 4s - loss: 0.6965 - acc: 0.5261 - f1_m: 0.1033 - precision_m: 0.5263 - recall_m: 0.0598 - val_loss: 0.6834 - val_acc: 0.5370 - val_f1_m: 0.1642 - val_precision_m: 0.7500 - val_recall_m: 0.0942 - 4s/epoch - 236ms/step
Epoch 2/25
19/19 - 0s - loss: 0.6697 - acc: 0.6154 - f1_m: 0.3665 - precision_m: 0.8351 - recall_m: 0.2450 - val_loss: 0.6695 - val_acc: 0.6296 - val_f1_m: 0.4410 - val_precision_m: 0.9167 - val_recall_m: 0.2992 - 75ms/epoch - 4ms/step
Epoch 3/25
19/19 - 0s - loss: 0.6517 - acc: 0.6675 - f1_m: 0.5365 - precision_m: 0.8302 - recall_m: 0.4060 - val_loss: 0.6542 - val_acc: 0.6728 - val_f1_m: 0.5523 - val_precision_m: 0.8750 - val_recall_m: 0.4200 - 76ms/epoch - 4ms/step
Epoch 4/25
19/19 - 0s - loss: 0.6305 - acc: 0.7097 - f1_m: 0.6119 - precision_m: 0.8386 - recall_m: 0.4940 - val_loss: 0.6339 - val_acc: 0.7037 - val_f1_m: 0.6318 - val_precision_m: 0.8495 - val_recall_m: 0.5342 - 77ms/epoch - 4ms/step
Epoch 5/25
19/19 - 0s - loss: 0.6040 - acc: 0.7357 - f1_m: 0.6734 - precision_m: 0.8362 - recall_m: 0.5748 - val_loss: 0.6104 - val_acc: 0.7346 - val_f1_m: 0.6750 - val_precision_m: 0.8512 - val_recall_m: 0.5917 - 76ms/epoch - 4ms/step
Epoch 6/25
19/19 - 0s - loss: 0.5739 - acc: 0.7605 - f1_m: 0.7165 - precision_m: 0.8371 - recall_m: 0.6437 - val_loss: 0.5806 - val_acc: 0.7778 - val_f1_m: 0.7432 - val_precision_m: 0.8598 - val_recall_m: 0.6750 - 77ms/epoch - 4ms/step
Epoch 7/25
19/19 - 0s - loss: 0.5397 - acc: 0.7816 - f1_m: 0.7522 - precision_m: 0.8293 - recall_m: 0.7059 - val_loss: 0.5449 - val_acc: 0.7840 - val_f1_m: 0.7619 - val_precision_m: 0.8438 - val_recall_m: 0.7183 - 76ms/epoch - 4ms/step
Epoch 8/25
19/19 - 0s - loss: 0.5016 - acc: 0.7965 - f1_m: 0.7764 - precision_m: 0.8147 - recall_m: 0.7641 - val_loss: 0.5091 - val_acc: 0.7901 - val_f1_m: 0.7816 - val_precision_m: 0.8083 - val_recall_m: 0.7892 - 76ms/epoch - 4ms/step
Epoch 9/25
19/19 - 0s - loss: 0.4673 - acc: 0.8077 - f1_m: 0.7961 - precision_m: 0.8095 - recall_m: 0.8079 - val_loss: 0.4777 - val_acc: 0.7901 - val_f1_m: 0.7859 - val_precision_m: 0.7884 - val_recall_m: 0.8192 - 76ms/epoch - 4ms/step
Epoch 10/25
19/19 - 0s - loss: 0.4413 - acc: 0.8139 - f1_m: 0.8045 - precision_m: 0.8074 - recall_m: 0.8302 - val_loss: 0.4569 - val_acc: 0.7901 - val_f1_m: 0.7859 - val_precision_m: 0.7884 - val_recall_m: 0.8192 - 72ms/epoch - 4ms/step
Epoch 11/25
19/19 - 0s - loss: 0.4229 - acc: 0.8201 - f1_m: 0.8111 - precision_m: 0.8131 - recall_m: 0.8386 - val_loss: 0.4456 - val_acc: 0.8025 - val_f1_m: 0.8000 - val_precision_m: 0.8046 - val_recall_m: 0.8400 - 78ms/epoch - 4ms/step
Epoch 12/25
19/19 - 0s - loss: 0.4097 - acc: 0.8263 - f1_m: 0.8172 - precision_m: 0.8140 - recall_m: 0.8462 - val_loss: 0.4415 - val_acc: 0.8025 - val_f1_m: 0.8000 - val_precision_m: 0.8046 - val_recall_m: 0.8400 - 79ms/epoch - 4ms/step
Epoch 13/25
19/19 - 0s - loss: 0.3971 - acc: 0.8300 - f1_m: 0.8195 - precision_m: 0.8155 - recall_m: 0.8487 - val_loss: 0.4404 - val_acc: 0.8025 - val_f1_m: 0.8000 - val_precision_m: 0.8046 - val_recall_m: 0.8400 - 79ms/epoch - 4ms/step
Epoch 14/25
19/19 - 0s - loss: 0.3886 - acc: 0.8325 - f1_m: 0.8231 - precision_m: 0.8169 - recall_m: 0.8555 - val_loss: 0.4404 - val_acc: 0.7963 - val_f1_m: 0.7952 - val_precision_m: 0.7992 - val_recall_m: 0.8400 - 79ms/epoch - 4ms/step
Epoch 15/25
19/19 - 0s - loss: 0.3800 - acc: 0.8362 - f1_m: 0.8263 - precision_m: 0.8213 - recall_m: 0.8555 - val_loss: 0.4415 - val_acc: 0.7963 - val_f1_m: 0.7952 - val_precision_m: 0.7992 - val_recall_m: 0.8400 - 80ms/epoch - 4ms/step
Epoch 16/25
19/19 - 0s - loss: 0.3733 - acc: 0.8400 - f1_m: 0.8318 - precision_m: 0.8223 - recall_m: 0.8648 - val_loss: 0.4431 - val_acc: 0.7963 - val_f1_m: 0.7948 - val_precision_m: 0.7941 - val_recall_m: 0.8400 - 76ms/epoch - 4ms/step
Epoch 17/25
19/19 - 0s - loss: 0.3664 - acc: 0.8437 - f1_m: 0.8351 - precision_m: 0.8261 - recall_m: 0.8665 - val_loss: 0.4448 - val_acc: 0.7963 - val_f1_m: 0.7948 - val_precision_m: 0.7941 - val_recall_m: 0.8400 - 75ms/epoch - 4ms/step
Epoch 18/25
19/19 - 0s - loss: 0.3604 - acc: 0.8462 - f1_m: 0.8393 - precision_m: 0.8304 - recall_m: 0.8707 - val_loss: 0.4468 - val_acc: 0.7963 - val_f1_m: 0.7948 - val_precision_m: 0.7941 - val_recall_m: 0.8400 - 78ms/epoch - 4ms/step
Epoch 19/25
19/19 - 0s - loss: 0.3545 - acc: 0.8511 - f1_m: 0.8438 - precision_m: 0.8363 - recall_m: 0.8733 - val_loss: 0.4489 - val_acc: 0.8025 - val_f1_m: 0.7999 - val_precision_m: 0.8000 - val_recall_m: 0.8400 - 72ms/epoch - 4ms/step
Epoch 20/25
19/19 - 0s - loss: 0.3490 - acc: 0.8524 - f1_m: 0.8448 - precision_m: 0.8380 - recall_m: 0.8733 - val_loss: 0.4511 - val_acc: 0.8025 - val_f1_m: 0.7999 - val_precision_m: 0.8000 - val_recall_m: 0.8400 - 76ms/epoch - 4ms/step
Epoch 21/25
19/19 - 0s - loss: 0.3441 - acc: 0.8610 - f1_m: 0.8528 - precision_m: 0.8473 - recall_m: 0.8765 - val_loss: 0.4532 - val_acc: 0.8025 - val_f1_m: 0.7999 - val_precision_m: 0.8000 - val_recall_m: 0.8400 - 77ms/epoch - 4ms/step
Epoch 22/25
19/19 - 0s - loss: 0.3386 - acc: 0.8660 - f1_m: 0.8575 - precision_m: 0.8522 - recall_m: 0.8793 - val_loss: 0.4549 - val_acc: 0.7963 - val_f1_m: 0.7931 - val_precision_m: 0.7987 - val_recall_m: 0.8300 - 74ms/epoch - 4ms/step
Epoch 23/25
19/19 - 0s - loss: 0.3338 - acc: 0.8697 - f1_m: 0.8610 - precision_m: 0.8565 - recall_m: 0.8819 - val_loss: 0.4561 - val_acc: 0.7963 - val_f1_m: 0.7931 - val_precision_m: 0.7987 - val_recall_m: 0.8300 - 77ms/epoch - 4ms/step
Epoch 24/25
19/19 - 0s - loss: 0.3285 - acc: 0.8722 - f1_m: 0.8634 - precision_m: 0.8602 - recall_m: 0.8824 - val_loss: 0.4580 - val_acc: 0.7963 - val_f1_m: 0.7931 - val_precision_m: 0.7987 - val_recall_m: 0.8300 - 73ms/epoch - 4ms/step
Epoch 25/25
19/19 - 0s - loss: 0.3237 - acc: 0.8747 - f1_m: 0.8663 - precision_m: 0.8634 - recall_m: 0.8854 - val_loss: 0.4600 - val_acc: 0.7963 - val_f1_m: 0.7910 - val_precision_m: 0.8077 - val_recall_m: 0.8200 - 82ms/epoch - 4ms/step
from matplotlib import pyplot as plt
%matplotlib inline
plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_4.history['f1_m'])
plt.ylabel('F1 Score %')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_4.history['val_f1_m'])
plt.title('Validation')
plt.subplot(2,2,3)
plt.plot(history_4.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')
plt.subplot(2,2,4)
plt.plot(history_4.history['val_loss'])
plt.xlabel('epochs')
Generally, model convergence was all over the map across multiple runs. I changed the number of epochs after a few test runs, and the model now converges between 10 and 15 epochs. However, I'd like to see the loss get a little lower overall to be confident I'm getting better performance, and I'm not seeing that between the original model 2 and this altered version. I do see overtraining occurring in this model, as the validation loss begins to trend back upward after about 12 epochs.
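As a sanity check on where overtraining begins, here is a minimal sketch for locating the epoch where validation loss bottoms out. The `best_epoch` helper and the `fake_history` values are illustrative stand-ins, not the lab's actual history object:

```python
# Minimal sketch: find the epoch with the lowest monitored validation metric
# in a Keras-style history dict (metric name -> list of per-epoch values).
def best_epoch(history_dict, monitor="val_loss"):
    """Return the 1-based epoch index with the lowest monitored value."""
    values = history_dict[monitor]
    return min(range(len(values)), key=values.__getitem__) + 1

# Illustrative values only: val_loss falls, then trends back upward after epoch 3
fake_history = {"val_loss": [0.68, 0.55, 0.44, 0.441, 0.46, 0.47]}
print(best_epoch(fake_history))  # -> 3
```

For a real training run, Keras also provides `tf.keras.callbacks.EarlyStopping(monitor='val_loss', restore_best_weights=True)`, which stops training and rolls back to the best epoch automatically.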
To verify whether my models are really different, I'll perform a statistical analysis of the two.
# Visualize some metrics associated with this model
# Source: Modified from in-class lecture
import numpy as np
from sklearn import metrics as mt

# rebuild the test labels from the tf.data pipeline
y_test = tf.concat([y for x, y in ds_test], axis=0)
y_test = y_test.numpy()

# now let's see how well the model performed
yhat_proba_4 = training_model_4.predict(ds_test)  # sigmoid output probabilities
# use squeeze to get rid of any extra dimensions
yhat_4 = np.round(yhat_proba_4.squeeze())  # round to get binary class
conf_mat_4 = mt.confusion_matrix(y_test, yhat_4)
print(conf_mat_4)
print(mt.classification_report(y_test, yhat_4))
# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 2018.
import seaborn as sns

# Create a pandas dataframe with the class labels
conf_df_4 = pd.DataFrame(conf_mat_4, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())
# Create heatmap
sns.heatmap(conf_df_4, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 1s 2ms/step
[[64 16]
[17 65]]
precision recall f1-score support
0 0.79 0.80 0.80 80
1 0.80 0.79 0.80 82
accuracy 0.80 162
macro avg 0.80 0.80 0.80 162
weighted avg 0.80 0.80 0.80 162
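The 0.80/0.79 figures in the report above can be recomputed by hand from the printed confusion matrix. A small sketch using only NumPy, with the matrix values copied from the output above:

```python
import numpy as np

# Confusion matrix printed above: rows = true class, columns = predicted class
conf = np.array([[64, 16],
                 [17, 65]])

tn, fp, fn, tp = conf.ravel()
precision_1 = tp / (tp + fp)          # 65 / 81
recall_1    = tp / (tp + fn)          # 65 / 82
accuracy    = (tp + tn) / conf.sum()  # 129 / 162

print(round(precision_1, 2), round(recall_1, 2), round(accuracy, 2))  # -> 0.8 0.79 0.8
```

These match the precision, recall, and accuracy that `classification_report` prints for class 1.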
Comparison between model 2 and model 4¶
To determine which model truly performed better, I need to understand whether the models are really different from one another.
Initially, I split my data using KFold with 6 folds, then used those training and test sets for each model, so the models are all trained on exactly the same data. Instead of looking at the f1_scores for each fold, I'm going to compare the f1_scores measured during model fitting.
With another dataset this could cause problems, since I'm comparing results over the entire dataset rather than fold by fold. However, because my prediction classes are perfectly balanced and the data for each model is identical, I believe this will still yield useful information for comparison.
Note that ideally I'd run a cross_val_score here, but despite multiple attempts I couldn't get it to work with the Keras models.
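One way to sketch this paired comparison without cross_val_score is `scipy.stats.ttest_rel`, which tests whether matched measurements differ. The arrays here are illustrative stand-ins for `1 - val_f1_m` from the two histories, not the lab's actual numbers:

```python
from scipy.stats import ttest_rel

# Illustrative per-epoch validation error rates for two models (made-up numbers,
# standing in for 1 - val_f1_m measured on identical data splits)
model_a_err = [0.40, 0.31, 0.25, 0.22, 0.21, 0.20]
model_b_err = [0.38, 0.30, 0.26, 0.21, 0.20, 0.20]

# Paired t-test on the matched error measurements
stat, p_value = ttest_rel(model_a_err, model_b_err)
if p_value < 0.05:
    print("statistically different at 95% confidence")
else:
    print("not statistically different at 95% confidence")
```

This reaches the same kind of conclusion as the confidence-interval approach below a p-value above 0.05 corresponds to an interval that contains 0.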
from scipy.stats import t

# Get the histories of val_f1 scores from my two models for comparison
f1_score_model_2 = history_2.history['val_f1_m']
f1_score_model_4 = history_4.history['val_f1_m']

# get error rates for both models' f1 scores
model_2_err = [1 - f1 for f1 in f1_score_model_2]
model_4_err = [1 - f1 for f1 in f1_score_model_4]

# paired differences in error between the two models
d = [e2 - e4 for e2, e4 in zip(model_2_err, model_4_err)]
dbar = sum(d) / len(d)
stdtot = np.std(d)

confidence_level = 0.95
degrees_of_freedom = 12  # roughly the number of epochs to convergence

# Calculate the critical value (renamed so it doesn't shadow scipy.stats.t)
t_crit = t.ppf((1 + confidence_level) / 2, degrees_of_freedom)
print('Range of:', dbar - t_crit * stdtot, dbar + t_crit * stdtot, 'between model 2 and model 4')
Range of: -0.3139904159770698 0.12504712725344203 between model 2 and model 4
An interesting note about my statistical analysis: I ran these models multiple times. There were times the comparison showed no statistical difference (the range contained 0), and times it showed there was one (the range did not contain 0), with no model changes at all. That doesn't fill me with confidence that I've implemented everything correctly, and it gives me even less faith in statistics. Here are a few values from sample runs:
- Range of: -0.33352213888834137 0.20063908685714227 between model 2 and model 4 -- statistically not different (range contains 0)
- Range of: -0.086830071054072 0.0883182957193959 between model 2 and model 4 -- statistically not different (range contains 0)
- Range of: -0.08561457731365749 0.06702024875123888 between model 2 and model 4 -- statistically not different (range contains 0)
Because of these mixed results I'll speak to them from two angles. If the two models are NOT statistically different, then any given run may produce values better or worse than the other, and the models perform roughly the same. If instead they ARE statistically different, the difference can only be minor, since the results overlap so frequently.
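The contains-zero decision rule can be wrapped in a small helper. This sketch (hypothetical `paired_ci_contains_zero`, illustrative inputs) uses the standard error of the mean difference with n-1 degrees of freedom, a slightly more conventional form than the raw standard deviation used in the cell above:

```python
import numpy as np
from scipy import stats

def paired_ci_contains_zero(err_a, err_b, confidence=0.95):
    """CI for the mean paired error difference; True means the two models
    are not statistically distinguishable at this confidence level."""
    d = np.asarray(err_a) - np.asarray(err_b)
    dbar = d.mean()
    se = d.std(ddof=1) / np.sqrt(len(d))  # standard error of the mean difference
    t_crit = stats.t.ppf((1 + confidence) / 2, df=len(d) - 1)
    lo, hi = dbar - t_crit * se, dbar + t_crit * se
    return bool(lo <= 0 <= hi)

# Illustrative error rates (not the lab's actual histories)
print(paired_ci_contains_zero([0.3, 0.25, 0.2, 0.22], [0.28, 0.26, 0.21, 0.2]))  # -> True
```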
Performance Comparison vs MLP (Deep Side)¶
For this comparison I'll be using just the deep side of my wide and deep neural network along with my best performing model. I've mentioned the interchangeability of "best model" several times throughout this notebook, so which model I select does not truly make a large difference here. For the sake of consistency, I'm going to use model 2 as I did for my previous section's comparisons.
# Source: Modified from in-class lecture to match my dataset
from tensorflow.keras.utils import FeatureSpace
feature_space_mlp = FeatureSpace(
    features={
        # Categorical features encoded as strings
        "type_of_breast_surgery": FeatureSpace.string_categorical(num_oov_indices=0),
        "cancer_type_detailed": FeatureSpace.string_categorical(num_oov_indices=0),
        "cellularity": FeatureSpace.string_categorical(num_oov_indices=0),
        "pam50_plus_claudin-low_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        "er_status_measured_by_ihc": FeatureSpace.string_categorical(num_oov_indices=0),
        "er_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "her2_status_measured_by_snp6": FeatureSpace.string_categorical(num_oov_indices=0),
        "her2_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "tumor_other_histologic_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        "inferred_menopausal_state": FeatureSpace.string_categorical(num_oov_indices=0),
        "integrative_cluster": FeatureSpace.string_categorical(num_oov_indices=0),
        "primary_tumor_laterality": FeatureSpace.string_categorical(num_oov_indices=0),
        "oncotree_code": FeatureSpace.string_categorical(num_oov_indices=0),
        "pr_status": FeatureSpace.string_categorical(num_oov_indices=0),
        "3-gene_classifier_subtype": FeatureSpace.string_categorical(num_oov_indices=0),
        # Numerical features to normalize (the normalization is learned:
        # each layer adapts its mean and variance from the training data)
        "chemotherapy": FeatureSpace.float_normalized(),
        "hormone_therapy": FeatureSpace.float_normalized(),
        "radio_therapy": FeatureSpace.float_normalized(),
        "age_at_diagnosis": FeatureSpace.float_normalized(),
        "neoplasm_histologic_grade": FeatureSpace.float_normalized(),
        "lymph_nodes_examined_positive": FeatureSpace.float_normalized(),
        "mutation_count": FeatureSpace.float_normalized(),
        "nottingham_prognostic_index": FeatureSpace.float_normalized(),
        "overall_survival_months": FeatureSpace.float_normalized(),
        "tumor_size": FeatureSpace.float_normalized(),
        "tumor_stage": FeatureSpace.float_normalized(),
    },
    output_mode="concat",
)
# now that we have specified the preprocessing, let's run it on the data
# create a version of the dataset that can be iterated without labels
train_ds_with_no_labels = ds_train.map(lambda x, _: x)
feature_space_mlp.adapt(train_ds_with_no_labels)  # initialize the feature map to this data
dict_inputs = feature_space_mlp.get_inputs()  # need the unprocessed features here, to gain access to each input

# for each crossed variable, make an embedding (this feature space defines no
# crosses, so the loop is a no-op kept for symmetry with the wide-and-deep model)
crossed_outputs = []
for col in feature_space_mlp.crossers.keys():
    x = setup_embedding_from_crossing(feature_space_mlp, col)
    # save these outputs in a list to concatenate later
    crossed_outputs.append(x)

# the wide branch is intentionally omitted for the MLP comparison
# wide_branch = Concatenate(name='wide_concat')(crossed_outputs)

# reset this input branch
all_deep_branch_outputs = []

# for each numeric variable, just add it in after casting
for idx, col in enumerate(numeric_headers):
    x = feature_space_mlp.preprocessors[col].output
    x = tf.cast(x, float)  # cast an integer as a float here
    all_deep_branch_outputs.append(x)

# for each categorical variable
for col in categorical_headers:
    # get the output tensor from the embedding layer
    x = setup_embedding_from_categorical(feature_space_mlp, col)
    # save these outputs in a list to concatenate later
    all_deep_branch_outputs.append(x)

# merge the deep branches together
deep_branch = Concatenate(name='embed_concat')(all_deep_branch_outputs)
deep_branch = Dense(units=34, activation='relu', name='deep1')(deep_branch)
deep_branch = Dense(units=17, activation='relu', name='deep2')(deep_branch)
deep_branch = Dense(units=8, activation='relu', name='deep3')(deep_branch)  # Changed from 10 to 8 neurons
deep_branch = Dense(units=4, activation='relu', name='deep4')(deep_branch)  # This is my new layer
deep_branch = Dense(units=1, activation='sigmoid', name='deep5')(deep_branch)  # sigmoid output makes the deep side a complete MLP
training_model_mlp = keras.Model(inputs=dict_inputs, outputs=deep_branch)

# Source: https://datascience.stackexchange.com/questions/45165/how-to-get-accuracy-f1-precision-and-recall-for-a-keras-model
training_model_mlp.compile(
    optimizer="adam", loss="binary_crossentropy", metrics=['acc', f1_m, precision_m, recall_m]
)
training_model_mlp.summary()
plot_model(
    training_model_mlp, to_file='model.png', show_shapes=True, show_layer_names=True,
    rankdir='LR', expand_nested=False, dpi=96
)
Model: "model_39"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
type_of_breast_surgery (InputL [(None, 1)] 0 []
ayer)
cancer_type_detailed (InputLay [(None, 1)] 0 []
er)
cellularity (InputLayer) [(None, 1)] 0 []
pam50_plus_claudin-low_subtype [(None, 1)] 0 []
(InputLayer)
er_status_measured_by_ihc (Inp [(None, 1)] 0 []
utLayer)
er_status (InputLayer) [(None, 1)] 0 []
her2_status_measured_by_snp6 ( [(None, 1)] 0 []
InputLayer)
her2_status (InputLayer) [(None, 1)] 0 []
tumor_other_histologic_subtype [(None, 1)] 0 []
(InputLayer)
inferred_menopausal_state (Inp [(None, 1)] 0 []
utLayer)
integrative_cluster (InputLaye [(None, 1)] 0 []
r)
primary_tumor_laterality (Inpu [(None, 1)] 0 []
tLayer)
oncotree_code (InputLayer) [(None, 1)] 0 []
pr_status (InputLayer) [(None, 1)] 0 []
3-gene_classifier_subtype (Inp [(None, 1)] 0 []
utLayer)
age_at_diagnosis (InputLayer) [(None, 1)] 0 []
neoplasm_histologic_grade (Inp [(None, 1)] 0 []
utLayer)
lymph_nodes_examined_positive [(None, 1)] 0 []
(InputLayer)
mutation_count (InputLayer) [(None, 1)] 0 []
nottingham_prognostic_index (I [(None, 1)] 0 []
nputLayer)
overall_survival_months (Input [(None, 1)] 0 []
Layer)
tumor_size (InputLayer) [(None, 1)] 0 []
tumor_stage (InputLayer) [(None, 1)] 0 []
chemotherapy (InputLayer) [(None, 1)] 0 []
radio_therapy (InputLayer) [(None, 1)] 0 []
hormone_therapy (InputLayer) [(None, 1)] 0 []
string_categorical_454_preproc (None, 1) 0 ['type_of_breast_surgery[0][0]']
essor (StringLookup)
string_categorical_455_preproc (None, 1) 0 ['cancer_type_detailed[0][0]']
essor (StringLookup)
string_categorical_456_preproc (None, 1) 0 ['cellularity[0][0]']
essor (StringLookup)
string_categorical_457_preproc (None, 1) 0 ['pam50_plus_claudin-low_subtype[
essor (StringLookup) 0][0]']
string_categorical_458_preproc (None, 1) 0 ['er_status_measured_by_ihc[0][0]
essor (StringLookup) ']
string_categorical_459_preproc (None, 1) 0 ['er_status[0][0]']
essor (StringLookup)
string_categorical_460_preproc (None, 1) 0 ['her2_status_measured_by_snp6[0]
essor (StringLookup) [0]']
string_categorical_461_preproc (None, 1) 0 ['her2_status[0][0]']
essor (StringLookup)
string_categorical_462_preproc (None, 1) 0 ['tumor_other_histologic_subtype[
essor (StringLookup) 0][0]']
string_categorical_463_preproc (None, 1) 0 ['inferred_menopausal_state[0][0]
essor (StringLookup) ']
string_categorical_464_preproc (None, 1) 0 ['integrative_cluster[0][0]']
essor (StringLookup)
string_categorical_465_preproc (None, 1) 0 ['primary_tumor_laterality[0][0]'
essor (StringLookup) ]
string_categorical_466_preproc (None, 1) 0 ['oncotree_code[0][0]']
essor (StringLookup)
string_categorical_467_preproc (None, 1) 0 ['pr_status[0][0]']
essor (StringLookup)
string_categorical_468_preproc (None, 1) 0 ['3-gene_classifier_subtype[0][0]
essor (StringLookup) ']
float_normalized_328_preproces (None, 1) 3 ['age_at_diagnosis[0][0]']
sor (Normalization)
float_normalized_329_preproces (None, 1) 3 ['neoplasm_histologic_grade[0][0]
sor (Normalization) ']
float_normalized_330_preproces (None, 1) 3 ['lymph_nodes_examined_positive[0
sor (Normalization) ][0]']
float_normalized_331_preproces (None, 1) 3 ['mutation_count[0][0]']
sor (Normalization)
float_normalized_332_preproces (None, 1) 3 ['nottingham_prognostic_index[0][
sor (Normalization) 0]']
float_normalized_333_preproces (None, 1) 3 ['overall_survival_months[0][0]']
sor (Normalization)
float_normalized_334_preproces (None, 1) 3 ['tumor_size[0][0]']
sor (Normalization)
float_normalized_335_preproces (None, 1) 3 ['tumor_stage[0][0]']
sor (Normalization)
float_normalized_325_preproces (None, 1) 3 ['chemotherapy[0][0]']
sor (Normalization)
float_normalized_327_preproces (None, 1) 3 ['radio_therapy[0][0]']
sor (Normalization)
float_normalized_326_preproces (None, 1) 3 ['hormone_therapy[0][0]']
sor (Normalization)
type_of_breast_surgery_embed ( (None, 1, 1) 2 ['string_categorical_454_preproce
Embedding) ssor[0][0]']
cancer_type_detailed_embed (Em (None, 1, 2) 10 ['string_categorical_455_preproce
bedding) ssor[0][0]']
cellularity_embed (Embedding) (None, 1, 1) 3 ['string_categorical_456_preproce
ssor[0][0]']
pam50_plus_claudin-low_subtype (None, 1, 2) 14 ['string_categorical_457_preproce
_embed (Embedding) ssor[0][0]']
er_status_measured_by_ihc_embe (None, 1, 1) 2 ['string_categorical_458_preproce
d (Embedding) ssor[0][0]']
er_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_459_preproce
ssor[0][0]']
her2_status_measured_by_snp6_e (None, 1, 2) 8 ['string_categorical_460_preproce
mbed (Embedding) ssor[0][0]']
her2_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_461_preproce
ssor[0][0]']
tumor_other_histologic_subtype (None, 1, 2) 14 ['string_categorical_462_preproce
_embed (Embedding) ssor[0][0]']
inferred_menopausal_state_embe (None, 1, 1) 2 ['string_categorical_463_preproce
d (Embedding) ssor[0][0]']
integrative_cluster_embed (Emb (None, 1, 3) 33 ['string_categorical_464_preproce
edding) ssor[0][0]']
primary_tumor_laterality_embed (None, 1, 1) 2 ['string_categorical_465_preproce
(Embedding) ssor[0][0]']
oncotree_code_embed (Embedding (None, 1, 2) 10 ['string_categorical_466_preproce
) ssor[0][0]']
pr_status_embed (Embedding) (None, 1, 1) 2 ['string_categorical_467_preproce
ssor[0][0]']
3-gene_classifier_subtype_embe (None, 1, 2) 8 ['string_categorical_468_preproce
d (Embedding) ssor[0][0]']
tf.cast_429 (TFOpLambda) (None, 1) 0 ['float_normalized_328_preprocess
or[0][0]']
tf.cast_430 (TFOpLambda) (None, 1) 0 ['float_normalized_329_preprocess
or[0][0]']
tf.cast_431 (TFOpLambda) (None, 1) 0 ['float_normalized_330_preprocess
or[0][0]']
tf.cast_432 (TFOpLambda) (None, 1) 0 ['float_normalized_331_preprocess
or[0][0]']
tf.cast_433 (TFOpLambda) (None, 1) 0 ['float_normalized_332_preprocess
or[0][0]']
tf.cast_434 (TFOpLambda) (None, 1) 0 ['float_normalized_333_preprocess
or[0][0]']
tf.cast_435 (TFOpLambda) (None, 1) 0 ['float_normalized_334_preprocess
or[0][0]']
tf.cast_436 (TFOpLambda) (None, 1) 0 ['float_normalized_335_preprocess
or[0][0]']
tf.cast_437 (TFOpLambda) (None, 1) 0 ['float_normalized_325_preprocess
or[0][0]']
tf.cast_438 (TFOpLambda) (None, 1) 0 ['float_normalized_327_preprocess
or[0][0]']
tf.cast_439 (TFOpLambda) (None, 1) 0 ['float_normalized_326_preprocess
or[0][0]']
flatten_689 (Flatten) (None, 1) 0 ['type_of_breast_surgery_embed[0]
[0]']
flatten_690 (Flatten) (None, 2) 0 ['cancer_type_detailed_embed[0][0
]']
flatten_691 (Flatten) (None, 1) 0 ['cellularity_embed[0][0]']
flatten_692 (Flatten) (None, 2) 0 ['pam50_plus_claudin-low_subtype_
embed[0][0]']
flatten_693 (Flatten) (None, 1) 0 ['er_status_measured_by_ihc_embed
[0][0]']
flatten_694 (Flatten) (None, 1) 0 ['er_status_embed[0][0]']
flatten_695 (Flatten) (None, 2) 0 ['her2_status_measured_by_snp6_em
bed[0][0]']
flatten_696 (Flatten) (None, 1) 0 ['her2_status_embed[0][0]']
flatten_697 (Flatten) (None, 2) 0 ['tumor_other_histologic_subtype_
embed[0][0]']
flatten_698 (Flatten) (None, 1) 0 ['inferred_menopausal_state_embed
[0][0]']
flatten_699 (Flatten) (None, 3) 0 ['integrative_cluster_embed[0][0]
']
flatten_700 (Flatten) (None, 1) 0 ['primary_tumor_laterality_embed[
0][0]']
flatten_701 (Flatten) (None, 2) 0 ['oncotree_code_embed[0][0]']
flatten_702 (Flatten) (None, 1) 0 ['pr_status_embed[0][0]']
flatten_703 (Flatten) (None, 2) 0 ['3-gene_classifier_subtype_embed
[0][0]']
embed_concat (Concatenate) (None, 34) 0 ['tf.cast_429[0][0]',
'tf.cast_430[0][0]',
'tf.cast_431[0][0]',
'tf.cast_432[0][0]',
'tf.cast_433[0][0]',
'tf.cast_434[0][0]',
'tf.cast_435[0][0]',
'tf.cast_436[0][0]',
'tf.cast_437[0][0]',
'tf.cast_438[0][0]',
'tf.cast_439[0][0]',
'flatten_689[0][0]',
'flatten_690[0][0]',
'flatten_691[0][0]',
'flatten_692[0][0]',
'flatten_693[0][0]',
'flatten_694[0][0]',
'flatten_695[0][0]',
'flatten_696[0][0]',
'flatten_697[0][0]',
'flatten_698[0][0]',
'flatten_699[0][0]',
'flatten_700[0][0]',
'flatten_701[0][0]',
'flatten_702[0][0]',
'flatten_703[0][0]']
deep1 (Dense) (None, 34) 1190 ['embed_concat[0][0]']
deep2 (Dense) (None, 17) 595 ['deep1[0][0]']
deep3 (Dense) (None, 8) 144 ['deep2[0][0]']
deep4 (Dense) (None, 4) 36 ['deep3[0][0]']
deep5 (Dense) (None, 1) 5 ['deep4[0][0]']
==================================================================================================
Total params: 2,117
Trainable params: 2,084
Non-trainable params: 33
__________________________________________________________________________________________________
# train using the already processed features
history_mlp = training_model_mlp.fit(
    ds_train, epochs=35, validation_data=ds_test, verbose=2
)  # changed number of epochs
Epoch 1/35
19/19 - 6s - loss: 0.6599 - acc: 0.5645 - f1_m: 0.2866 - precision_m: 0.6511 - recall_m: 0.1948 - val_loss: 0.6304 - val_acc: 0.6543 - val_f1_m: 0.5149 - val_precision_m: 0.8018 - val_recall_m: 0.4025 - 6s/epoch - 311ms/step
Epoch 2/35
19/19 - 0s - loss: 0.6023 - acc: 0.7370 - f1_m: 0.6602 - precision_m: 0.8006 - recall_m: 0.5892 - val_loss: 0.5950 - val_acc: 0.7778 - val_f1_m: 0.7736 - val_precision_m: 0.7866 - val_recall_m: 0.8067 - 71ms/epoch - 4ms/step
Epoch 3/35
19/19 - 0s - loss: 0.5685 - acc: 0.7742 - f1_m: 0.7608 - precision_m: 0.7758 - recall_m: 0.7939 - val_loss: 0.5600 - val_acc: 0.8210 - val_f1_m: 0.8179 - val_precision_m: 0.8201 - val_recall_m: 0.8700 - 70ms/epoch - 4ms/step
Epoch 4/35
19/19 - 0s - loss: 0.5355 - acc: 0.7816 - f1_m: 0.7721 - precision_m: 0.7862 - recall_m: 0.8088 - val_loss: 0.5232 - val_acc: 0.8086 - val_f1_m: 0.8076 - val_precision_m: 0.8072 - val_recall_m: 0.8600 - 71ms/epoch - 4ms/step
Epoch 5/35
19/19 - 0s - loss: 0.5016 - acc: 0.8040 - f1_m: 0.7949 - precision_m: 0.8070 - recall_m: 0.8272 - val_loss: 0.4926 - val_acc: 0.7963 - val_f1_m: 0.7979 - val_precision_m: 0.7853 - val_recall_m: 0.8600 - 70ms/epoch - 4ms/step
Epoch 6/35
19/19 - 0s - loss: 0.4780 - acc: 0.8065 - f1_m: 0.7978 - precision_m: 0.8048 - recall_m: 0.8335 - val_loss: 0.4691 - val_acc: 0.8025 - val_f1_m: 0.8030 - val_precision_m: 0.7958 - val_recall_m: 0.8600 - 78ms/epoch - 4ms/step
Epoch 7/35
19/19 - 0s - loss: 0.4594 - acc: 0.8164 - f1_m: 0.8076 - precision_m: 0.8070 - recall_m: 0.8477 - val_loss: 0.4537 - val_acc: 0.7963 - val_f1_m: 0.7981 - val_precision_m: 0.7961 - val_recall_m: 0.8525 - 80ms/epoch - 4ms/step
Epoch 8/35
19/19 - 0s - loss: 0.4458 - acc: 0.8201 - f1_m: 0.8107 - precision_m: 0.8083 - recall_m: 0.8510 - val_loss: 0.4444 - val_acc: 0.7963 - val_f1_m: 0.7981 - val_precision_m: 0.7961 - val_recall_m: 0.8525 - 92ms/epoch - 5ms/step
Epoch 9/35
19/19 - 0s - loss: 0.4366 - acc: 0.8213 - f1_m: 0.8124 - precision_m: 0.8086 - recall_m: 0.8538 - val_loss: 0.4384 - val_acc: 0.7963 - val_f1_m: 0.7981 - val_precision_m: 0.7961 - val_recall_m: 0.8525 - 64ms/epoch - 3ms/step
Epoch 10/35
19/19 - 0s - loss: 0.4282 - acc: 0.8226 - f1_m: 0.8136 - precision_m: 0.8107 - recall_m: 0.8542 - val_loss: 0.4344 - val_acc: 0.7901 - val_f1_m: 0.7884 - val_precision_m: 0.7978 - val_recall_m: 0.8300 - 75ms/epoch - 4ms/step
Epoch 11/35
19/19 - 0s - loss: 0.4210 - acc: 0.8300 - f1_m: 0.8192 - precision_m: 0.8189 - recall_m: 0.8542 - val_loss: 0.4320 - val_acc: 0.7840 - val_f1_m: 0.7824 - val_precision_m: 0.7973 - val_recall_m: 0.8200 - 73ms/epoch - 4ms/step
Epoch 12/35
19/19 - 0s - loss: 0.4135 - acc: 0.8313 - f1_m: 0.8193 - precision_m: 0.8207 - recall_m: 0.8512 - val_loss: 0.4312 - val_acc: 0.7901 - val_f1_m: 0.7897 - val_precision_m: 0.7993 - val_recall_m: 0.8325 - 75ms/epoch - 4ms/step
Epoch 13/35
19/19 - 0s - loss: 0.4080 - acc: 0.8325 - f1_m: 0.8208 - precision_m: 0.8202 - recall_m: 0.8543 - val_loss: 0.4302 - val_acc: 0.7901 - val_f1_m: 0.7897 - val_precision_m: 0.7993 - val_recall_m: 0.8325 - 73ms/epoch - 4ms/step
Epoch 14/35
19/19 - 0s - loss: 0.4024 - acc: 0.8375 - f1_m: 0.8259 - precision_m: 0.8238 - recall_m: 0.8592 - val_loss: 0.4299 - val_acc: 0.7901 - val_f1_m: 0.7897 - val_precision_m: 0.7993 - val_recall_m: 0.8325 - 71ms/epoch - 4ms/step
Epoch 15/35
19/19 - 0s - loss: 0.3973 - acc: 0.8412 - f1_m: 0.8295 - precision_m: 0.8260 - recall_m: 0.8635 - val_loss: 0.4300 - val_acc: 0.7901 - val_f1_m: 0.7897 - val_precision_m: 0.7993 - val_recall_m: 0.8325 - 71ms/epoch - 4ms/step
Epoch 16/35
19/19 - 0s - loss: 0.3928 - acc: 0.8437 - f1_m: 0.8324 - precision_m: 0.8288 - recall_m: 0.8666 - val_loss: 0.4303 - val_acc: 0.7963 - val_f1_m: 0.7945 - val_precision_m: 0.8047 - val_recall_m: 0.8325 - 70ms/epoch - 4ms/step
Epoch 17/35
19/19 - 0s - loss: 0.3873 - acc: 0.8462 - f1_m: 0.8351 - precision_m: 0.8303 - recall_m: 0.8672 - val_loss: 0.4300 - val_acc: 0.8025 - val_f1_m: 0.7995 - val_precision_m: 0.8106 - val_recall_m: 0.8325 - 71ms/epoch - 4ms/step
Epoch 18/35
19/19 - 0s - loss: 0.3825 - acc: 0.8499 - f1_m: 0.8391 - precision_m: 0.8347 - recall_m: 0.8703 - val_loss: 0.4296 - val_acc: 0.8086 - val_f1_m: 0.8056 - val_precision_m: 0.8112 - val_recall_m: 0.8425 - 72ms/epoch - 4ms/step
Epoch 19/35
19/19 - 0s - loss: 0.3777 - acc: 0.8548 - f1_m: 0.8433 - precision_m: 0.8394 - recall_m: 0.8722 - val_loss: 0.4293 - val_acc: 0.8086 - val_f1_m: 0.8056 - val_precision_m: 0.8112 - val_recall_m: 0.8425 - 69ms/epoch - 4ms/step
Epoch 20/35
19/19 - 0s - loss: 0.3731 - acc: 0.8586 - f1_m: 0.8463 - precision_m: 0.8424 - recall_m: 0.8735 - val_loss: 0.4293 - val_acc: 0.8086 - val_f1_m: 0.8056 - val_precision_m: 0.8112 - val_recall_m: 0.8425 - 69ms/epoch - 4ms/step
Epoch 21/35
19/19 - 0s - loss: 0.3684 - acc: 0.8598 - f1_m: 0.8482 - precision_m: 0.8451 - recall_m: 0.8748 - val_loss: 0.4287 - val_acc: 0.8086 - val_f1_m: 0.8056 - val_precision_m: 0.8112 - val_recall_m: 0.8425 - 70ms/epoch - 4ms/step
Epoch 22/35
19/19 - 0s - loss: 0.3648 - acc: 0.8623 - f1_m: 0.8505 - precision_m: 0.8463 - recall_m: 0.8772 - val_loss: 0.4294 - val_acc: 0.8148 - val_f1_m: 0.8113 - val_precision_m: 0.8117 - val_recall_m: 0.8525 - 70ms/epoch - 4ms/step
Epoch 23/35
19/19 - 0s - loss: 0.3593 - acc: 0.8623 - f1_m: 0.8512 - precision_m: 0.8448 - recall_m: 0.8800 - val_loss: 0.4301 - val_acc: 0.8148 - val_f1_m: 0.8113 - val_precision_m: 0.8117 - val_recall_m: 0.8525 - 67ms/epoch - 4ms/step
Epoch 24/35
19/19 - 0s - loss: 0.3556 - acc: 0.8648 - f1_m: 0.8534 - precision_m: 0.8464 - recall_m: 0.8818 - val_loss: 0.4301 - val_acc: 0.8148 - val_f1_m: 0.8113 - val_precision_m: 0.8117 - val_recall_m: 0.8525 - 69ms/epoch - 4ms/step
Epoch 25/35
19/19 - 0s - loss: 0.3511 - acc: 0.8648 - f1_m: 0.8536 - precision_m: 0.8464 - recall_m: 0.8818 - val_loss: 0.4306 - val_acc: 0.8025 - val_f1_m: 0.8011 - val_precision_m: 0.7956 - val_recall_m: 0.8525 - 74ms/epoch - 4ms/step
Epoch 26/35
19/19 - 0s - loss: 0.3479 - acc: 0.8672 - f1_m: 0.8563 - precision_m: 0.8473 - recall_m: 0.8865 - val_loss: 0.4293 - val_acc: 0.8025 - val_f1_m: 0.8011 - val_precision_m: 0.7956 - val_recall_m: 0.8525 - 73ms/epoch - 4ms/step
Epoch 27/35
19/19 - 0s - loss: 0.3427 - acc: 0.8672 - f1_m: 0.8561 - precision_m: 0.8468 - recall_m: 0.8865 - val_loss: 0.4292 - val_acc: 0.7963 - val_f1_m: 0.7962 - val_precision_m: 0.7864 - val_recall_m: 0.8525 - 73ms/epoch - 4ms/step
Epoch 28/35
19/19 - 0s - loss: 0.3387 - acc: 0.8734 - f1_m: 0.8635 - precision_m: 0.8499 - recall_m: 0.8983 - val_loss: 0.4298 - val_acc: 0.7901 - val_f1_m: 0.7914 - val_precision_m: 0.7810 - val_recall_m: 0.8525 - 70ms/epoch - 4ms/step
Epoch 29/35
19/19 - 0s - loss: 0.3345 - acc: 0.8734 - f1_m: 0.8635 - precision_m: 0.8499 - recall_m: 0.8983 - val_loss: 0.4298 - val_acc: 0.7901 - val_f1_m: 0.7914 - val_precision_m: 0.7810 - val_recall_m: 0.8525 - 71ms/epoch - 4ms/step
Epoch 30/35
19/19 - 0s - loss: 0.3312 - acc: 0.8734 - f1_m: 0.8635 - precision_m: 0.8499 - recall_m: 0.8983 - val_loss: 0.4303 - val_acc: 0.7901 - val_f1_m: 0.7914 - val_precision_m: 0.7810 - val_recall_m: 0.8525 - 71ms/epoch - 4ms/step
Epoch 31/35
19/19 - 0s - loss: 0.3267 - acc: 0.8784 - f1_m: 0.8682 - precision_m: 0.8585 - recall_m: 0.8983 - val_loss: 0.4310 - val_acc: 0.7901 - val_f1_m: 0.7914 - val_precision_m: 0.7810 - val_recall_m: 0.8525 - 78ms/epoch - 4ms/step
Epoch 32/35
19/19 - 0s - loss: 0.3223 - acc: 0.8797 - f1_m: 0.8707 - precision_m: 0.8572 - recall_m: 0.9043 - val_loss: 0.4319 - val_acc: 0.7963 - val_f1_m: 0.7982 - val_precision_m: 0.7823 - val_recall_m: 0.8625 - 71ms/epoch - 4ms/step
Epoch 33/35
19/19 - 0s - loss: 0.3179 - acc: 0.8834 - f1_m: 0.8746 - precision_m: 0.8619 - recall_m: 0.9074 - val_loss: 0.4326 - val_acc: 0.7963 - val_f1_m: 0.7982 - val_precision_m: 0.7823 - val_recall_m: 0.8625 - 76ms/epoch - 4ms/step
Epoch 34/35
19/19 - 0s - loss: 0.3142 - acc: 0.8859 - f1_m: 0.8769 - precision_m: 0.8652 - recall_m: 0.9074 - val_loss: 0.4327 - val_acc: 0.8025 - val_f1_m: 0.8030 - val_precision_m: 0.7877 - val_recall_m: 0.8625 - 76ms/epoch - 4ms/step
Epoch 35/35
19/19 - 0s - loss: 0.3087 - acc: 0.8883 - f1_m: 0.8786 - precision_m: 0.8669 - recall_m: 0.9086 - val_loss: 0.4340 - val_acc: 0.8025 - val_f1_m: 0.8032 - val_precision_m: 0.7915 - val_recall_m: 0.8625 - 78ms/epoch - 4ms/step
from matplotlib import pyplot as plt
%matplotlib inline
plt.figure(figsize=(10,4))
plt.subplot(2,2,1)
plt.plot(history_mlp.history['f1_m'])
plt.ylabel('F1 Score')
plt.title('Training')
plt.subplot(2,2,2)
plt.plot(history_mlp.history['val_f1_m'])
plt.title('Validation')
plt.subplot(2,2,3)
plt.plot(history_mlp.history['loss'])
plt.ylabel('Training Loss')
plt.xlabel('epochs')
plt.subplot(2,2,4)
plt.plot(history_mlp.history['val_loss'])
plt.xlabel('epochs')
Text(0.5, 0, 'epochs')
# Visualize some metrics associated with this model
# Source: Modified from in-class lecture
# now let's see how well the model performed
yhat_proba_mlp = training_model_mlp.predict(ds_test) # sigmoid output probabilities
# use squeeze to drop any extra dimensions, then round to get binary class labels
yhat_4 = np.round(yhat_proba_mlp.squeeze())
conf_mat_4 = mt.confusion_matrix(y_test, yhat_4)
print(conf_mat_4)
print(mt.classification_report(y_test,yhat_4))
# Source: Albon, Chris. Machine Learning with Python Cookbook. O'Reilly Media, 20180309. VitalBook file.
# Create pandas dataframe
conf_df_4 = pd.DataFrame(conf_mat_4, index=bc_df.death_from_cancer.unique(), columns=bc_df.death_from_cancer.unique())
# Create heatmap
sns.heatmap(conf_df_4, annot=True, cbar=None, cmap="Blues")
plt.title("Confusion Matrix"), plt.tight_layout()
plt.ylabel("True Class"), plt.xlabel("Predicted Class")
plt.show()
4/4 [==============================] - 0s 2ms/step
[[64 16]
[17 65]]
precision recall f1-score support
0 0.79 0.80 0.80 80
1 0.80 0.79 0.80 82
accuracy 0.80 162
macro avg 0.80 0.80 0.80 162
weighted avg 0.80 0.80 0.80 162
from scipy.stats import t
# Get the histories of val_f1 scores from my two models for comparison
f1_score_model_2 = history_2.history['val_f1_m']
f1_score_model_4 = history_mlp.history['val_f1_m']
# Convert F1 scores to error rates for both models
model_2_err = [1 - f1 for f1 in f1_score_model_2]
model_4_err = [1 - f1 for f1 in f1_score_model_4]
# Paired per-epoch differences in error
d = [e2 - e4 for e2, e4 in zip(model_2_err, model_4_err)]
dbar = sum(d) / len(d)
stdtot = np.std(d)
confidence_level = 0.95
degrees_of_freedom = len(d) - 1
# Calculate the critical value (named t_crit to avoid shadowing the imported t distribution)
t_crit = t.ppf((1 + confidence_level) / 2, degrees_of_freedom)
print('Range of:', dbar - t_crit * stdtot, dbar + t_crit * stdtot, 'between model 2 and the mlp model')
Range of: -0.08233255758287983 0.07009268019996559 between model 2 and the mlp model
Everything here compares well with previous results: no significant change to the confusion matrix, and no significant change to the statistical analysis. Because the confidence interval spans zero, the models are not statistically different from one another.
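As a cross-check on the confidence-interval calculation above, scipy also provides a built-in paired t-test. This is a minimal sketch: the error lists below are illustrative stand-ins for `model_2_err` and `model_4_err` from the notebook, not the real histories.

```python
# Cross-check the paired comparison with scipy's paired t-test.
# The lists here are hypothetical stand-ins for the per-epoch error rates.
from scipy.stats import ttest_rel

model_a_err = [0.21, 0.22, 0.20, 0.19, 0.21, 0.20, 0.19, 0.20, 0.21, 0.20]
model_b_err = [0.20, 0.21, 0.21, 0.20, 0.20, 0.19, 0.20, 0.21, 0.20, 0.21]

# ttest_rel pairs the observations epoch-by-epoch, matching the d list above
t_stat, p_value = ttest_rel(model_a_err, model_b_err)
print(f"t = {t_stat:.3f}, p = {p_value:.3f}")
```

A large p-value (above 0.05) is consistent with the confidence interval containing zero: no statistically significant difference between the two models.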
Now let's check out the ROC curve for these two models.
# Load libraries
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, roc_auc_score
# Create true and false positive rates
false_positive_rate_mlp, true_positive_rate_mlp, threshold = roc_curve(y_test, yhat_proba_mlp)
false_positive_rate_2, true_positive_rate_2, threshold = roc_curve(y_test, yhat_proba_2)
# Plot ROC curve
plt.title("Receiver Operating Characteristic")
plt.plot(false_positive_rate_mlp, true_positive_rate_mlp, label='MLP curve')
plt.plot(false_positive_rate_2, true_positive_rate_2, label='Model 2')
plt.plot([0, 1], ls="--")
plt.plot([0, 0], [1, 0] , c=".7"), plt.plot([1, 1] , c=".7")
plt.ylabel("True Positive Rate")
plt.xlabel("False Positive Rate")
plt.legend()
plt.show()
The resulting ROC curves are incredibly close, which aligns with the statistical analysis showing little to no difference between these models. If I had to pick, my MLP-only model slightly outperforms, as its area under the curve appears marginally larger, though the difference is very slight.
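The "slightly larger area" claim can be quantified with `roc_auc_score`, which was imported above but not used. This sketch uses synthetic stand-ins for `y_test`, `yhat_proba_mlp`, and `yhat_proba_2`; in the notebook the real arrays would be passed instead.

```python
# Quantify the visual AUC comparison with roc_auc_score.
# Synthetic labels and scores stand in for y_test, yhat_proba_mlp, yhat_proba_2.
import numpy as np
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=162)  # same size as the test split
# scores correlated with the labels, clipped to valid probabilities
scores_mlp = np.clip(y_true * 0.60 + rng.normal(0.20, 0.25, size=162), 0, 1)
scores_m2 = np.clip(y_true * 0.55 + rng.normal(0.22, 0.28, size=162), 0, 1)

auc_mlp = roc_auc_score(y_true, scores_mlp)
auc_m2 = roc_auc_score(y_true, scores_m2)
print(f"MLP AUC: {auc_mlp:.3f}, Model 2 AUC: {auc_m2:.3f}")
```

A single scalar per model makes the "larger area" comparison explicit rather than eyeballed from the plot.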
Takeaways¶
This was an interesting dataset to analyze and try to understand how well we could predict the outcome. At roughly an 80 to 84% F1 score, I'm not confident I would deploy this model for use. While it could be a guiding point in discussions about patient outcomes, it contains enough error that I would be hesitant to rely on it as a predictor. A further exploration of some of the genomic features may be warranted to see whether they lend additional insight into this analysis.